This final chapter (other than the conclusion and the online chapters) is going to look a bit different. It contains far more code and far less prose than the previous chapters. We will introduce new Python keywords and libraries without discussing them. This chapter is meant to be the start of a significant research project for you. You see, we are going to implement many of the key pieces of the fastai and PyTorch APIs from scratch, building on nothing other than the components that we developed in <>! The key goal here is to end up with your own Learner class, and some callbacks: enough to be able to train a model on Imagenette, including examples of each of the key techniques we've studied. On the way to building Learner, we will create our own versions of Module, Parameter, and a parallel DataLoader so you have a very good idea of what those PyTorch classes do.

The end-of-chapter questionnaire is particularly important for this chapter. This is where we will be pointing you in the many interesting directions that you could take, using this chapter as your starting point. We suggest that you follow along with this chapter on your computer, and do lots of experiments, web searches, and whatever else you need to understand what's going on. You've built up the skills and expertise to do this in the rest of this book, so we think you are going to do great!


Let's begin by gathering (manually) some data.

Data

Have a look at the source to untar_data to see how it works. We'll use it here to access the 160-pixel version of Imagenette for use in this chapter:

untar_data??
Signature:
untar_data(
    url: 'str',
    archive: 'Path' = None,
    data: 'Path' = None,
    c_key: 'str' = 'data',
    force_download: 'bool' = False,
    base: 'str' = '~/.fastai',
) -> 'Path'
Source:   
def untar_data(
    url:str, # File to download
    archive:Path=None, # Optional override for `Config`'s `archive` key
    data:Path=None, # Optional override for `Config`'s `data` key
    c_key:str='data', # Key in `Config` where to extract file
    force_download:bool=False, # Setting to `True` will overwrite any existing copy of data
    base:str='~/.fastai' # Directory containing config file and base of relative paths
) -> Path: # Path to extracted file(s)
    "Download `url` using `FastDownload.get`"
    d = FastDownload(fastai_cfg(), module=fastai.data, archive=archive, data=data, base=base)
    return d.get(url, force=force_download, extract_key=c_key)
File:      ~/mambaforge/lib/python3.9/site-packages/fastai/data/external.py
Type:      function
FastDownload?
Object `FastDownload` not found.

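FastDownload is not defined in the notebook's namespace, so where does it come from? Listing the installed packages whose names start with "fast" gives a hint: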

!ls -d /root/mambaforge/lib/python3.9/site-packages/fast* | egrep -v '\-info'
/root/mambaforge/lib/python3.9/site-packages/fastai
/root/mambaforge/lib/python3.9/site-packages/fastbook
/root/mambaforge/lib/python3.9/site-packages/fastcore
/root/mambaforge/lib/python3.9/site-packages/fastdownload
/root/mambaforge/lib/python3.9/site-packages/fastjsonschema
/root/mambaforge/lib/python3.9/site-packages/fastprogress
/root/mambaforge/lib/python3.9/site-packages/fastrelease

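As seen above, several packages with the "fast" prefix are installed alongside fastai, and fastdownload is one of them. To read the source of something defined in one of these packages, import the package first and then inspect the name through it: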

import fastdownload
fastdownload.FastDownload??
Init signature:
fastdownload.FastDownload(
    cfg=None,
    base='~/.fastdownload',
    archive=None,
    data=None,
    module=None,
)
Docstring:      <no docstring>
Source:        
class FastDownload:
    def __init__(self, cfg=None, base='~/.fastdownload', archive=None, data=None, module=None):
        base = Path(base).expanduser().absolute()
        default = {'data':(data or 'data'), 'archive':(archive or 'archive')}
        self.cfg = Config(base, 'config.ini', create=default) if cfg is None else cfg
        self.module = checks_module(module)
        if data is not None: self.cfg['data'] = data
        if archive is not None: self.cfg['archive'] = archive

    def arch_path(self):
        "Path to archives"
        return self.cfg.path('archive')

    def data_path(self, extract_key='data', arch=None):
        "Path to extracted data"
        path = self.cfg.path(extract_key)
        return path if arch is None else path/remove_suffix(arch.stem, '.tar')

    def check(self, url, fpath):
        "Check whether size and hash of `fpath` matches stored data for `url` or data is missing"
        checks = read_checks(self.module).get(url)
        return not checks or path_stats(fpath)==checks

    def download(self, url, force=False):
        "Download `url` to archive path, unless exists and `self.check` fails and not `force`"
        self.arch_path().mkdir(exist_ok=True, parents=True)
        return download_and_check(url, urldest(url, self.arch_path()), self.module, force)

    def rm(self, url, rm_arch=True, rm_data=True, extract_key='data'):
        "Delete downloaded archive and extracted data for `url`"
        arch = urldest(url, self.arch_path())
        if rm_arch: arch.delete()
        if rm_data: self.data_path(extract_key, arch).delete()

    def update(self, url):
        "Store the hash and size in `download_checks.py`"
        update_checks(urldest(url, self.arch_path()), url, self.module)

    def extract(self, url, extract_key='data', force=False):
        "Extract archive already downloaded from `url`, overwriting existing if `force`"
        arch = urldest(url, self.arch_path())
        if not arch.exists(): raise Exception(f'{arch} does not exist')
        dest = self.data_path(extract_key)
        dest.mkdir(exist_ok=True, parents=True)
        return untar_dir(arch, dest, rename=True, overwrite=force)

    def get(self, url, extract_key='data', force=False):
        "Download and extract `url`, overwriting existing if `force`"
        if not force:
            data = self.data_path(extract_key, urldest(url, self.arch_path()))
            if data.exists(): return data
        self.download(url, force=force)
        return self.extract(url, extract_key=extract_key, force=force)
File:           ~/mambaforge/lib/python3.9/site-packages/fastdownload/core.py
Type:           type
Subclasses:     

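In other words, ?? can only show the source of a name that is reachable from the current namespace. Functions and classes that are only called indirectly (such as FastDownload inside untar_data) have to be imported, or accessed through their module, before their source can be inspected.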
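Outside IPython, similar information is available from the standard library's inspect module. The following is a minimal sketch, using the standard-library json module as a stand-in target (an assumption for illustration; any importable object implemented in Python should behave the same way). It shows how to find the module that defines a name and then read its source:

```python
import inspect
import json

# Stand-in target: any function or class implemented in Python would do.
target = json.dumps

# `__module__` names the module that defines the object, which tells you what to import.
defining_module = target.__module__

# `getsourcefile` returns the path of the defining file, `getsource` the source text.
source_file = inspect.getsourcefile(target)
source_text = inspect.getsource(target)

print(defining_module)
print(source_file)
print(source_text.splitlines()[0])
```

The same two steps (look up `__module__`, then import that module) are what resolved the FastDownload question above.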

URLs??
Init signature: URLs()
Source:        
class URLs():
    "Global constants for dataset and model URLs."
    LOCAL_PATH = Path.cwd()
    MDL = 'http://files.fast.ai/models/'
    GOOGLE = 'https://storage.googleapis.com/'
    S3  = 'https://s3.amazonaws.com/fast-ai-'
    URL = f'{S3}sample/'

    S3_IMAGE    = f'{S3}imageclas/'
    S3_IMAGELOC = f'{S3}imagelocal/'
    S3_AUDI     = f'{S3}audio/'
    S3_NLP      = f'{S3}nlp/'
    S3_COCO     = f'{S3}coco/'
    S3_MODEL    = f'{S3}modelzoo/'

    # main datasets
    ADULT_SAMPLE        = f'{URL}adult_sample.tgz'
    BIWI_SAMPLE         = f'{URL}biwi_sample.tgz'
    CIFAR               = f'{URL}cifar10.tgz'
    COCO_SAMPLE         = f'{S3_COCO}coco_sample.tgz'
    COCO_TINY           = f'{S3_COCO}coco_tiny.tgz'
    HUMAN_NUMBERS       = f'{URL}human_numbers.tgz'
    IMDB                = f'{S3_NLP}imdb.tgz'
    IMDB_SAMPLE         = f'{URL}imdb_sample.tgz'
    ML_SAMPLE           = f'{URL}movie_lens_sample.tgz'
    ML_100k             = 'https://files.grouplens.org/datasets/movielens/ml-100k.zip'
    MNIST_SAMPLE        = f'{URL}mnist_sample.tgz'
    MNIST_TINY          = f'{URL}mnist_tiny.tgz'
    MNIST_VAR_SIZE_TINY = f'{S3_IMAGE}mnist_var_size_tiny.tgz'
    PLANET_SAMPLE       = f'{URL}planet_sample.tgz'
    PLANET_TINY         = f'{URL}planet_tiny.tgz'
    IMAGENETTE          = f'{S3_IMAGE}imagenette2.tgz'
    IMAGENETTE_160      = f'{S3_IMAGE}imagenette2-160.tgz'
    IMAGENETTE_320      = f'{S3_IMAGE}imagenette2-320.tgz'
    IMAGEWOOF           = f'{S3_IMAGE}imagewoof2.tgz'
    IMAGEWOOF_160       = f'{S3_IMAGE}imagewoof2-160.tgz'
    IMAGEWOOF_320       = f'{S3_IMAGE}imagewoof2-320.tgz'
    IMAGEWANG           = f'{S3_IMAGE}imagewang.tgz'
    IMAGEWANG_160       = f'{S3_IMAGE}imagewang-160.tgz'
    IMAGEWANG_320       = f'{S3_IMAGE}imagewang-320.tgz'

    # kaggle competitions download dogs-vs-cats -p {DOGS.absolute()}
    DOGS = f'{URL}dogscats.tgz'

    # image classification datasets
    CALTECH_101  = f'{S3_IMAGE}caltech_101.tgz'
    CARS         = f'{S3_IMAGE}stanford-cars.tgz'
    CIFAR_100    = f'{S3_IMAGE}cifar100.tgz'
    CUB_200_2011 = f'{S3_IMAGE}CUB_200_2011.tgz'
    FLOWERS      = f'{S3_IMAGE}oxford-102-flowers.tgz'
    FOOD         = f'{S3_IMAGE}food-101.tgz'
    MNIST        = f'{S3_IMAGE}mnist_png.tgz'
    PETS         = f'{S3_IMAGE}oxford-iiit-pet.tgz'

    # NLP datasets
    AG_NEWS                 = f'{S3_NLP}ag_news_csv.tgz'
    AMAZON_REVIEWS          = f'{S3_NLP}amazon_review_full_csv.tgz'
    AMAZON_REVIEWS_POLARITY = f'{S3_NLP}amazon_review_polarity_csv.tgz'
    DBPEDIA                 = f'{S3_NLP}dbpedia_csv.tgz'
    MT_ENG_FRA              = f'{S3_NLP}giga-fren.tgz'
    SOGOU_NEWS              = f'{S3_NLP}sogou_news_csv.tgz'
    WIKITEXT                = f'{S3_NLP}wikitext-103.tgz'
    WIKITEXT_TINY           = f'{S3_NLP}wikitext-2.tgz'
    YAHOO_ANSWERS           = f'{S3_NLP}yahoo_answers_csv.tgz'
    YELP_REVIEWS            = f'{S3_NLP}yelp_review_full_csv.tgz'
    YELP_REVIEWS_POLARITY   = f'{S3_NLP}yelp_review_polarity_csv.tgz'

    # Image localization datasets
    BIWI_HEAD_POSE     = f"{S3_IMAGELOC}biwi_head_pose.tgz"
    CAMVID             = f'{S3_IMAGELOC}camvid.tgz'
    CAMVID_TINY        = f'{URL}camvid_tiny.tgz'
    LSUN_BEDROOMS      = f'{S3_IMAGE}bedroom.tgz'
    PASCAL_2007        = f'{S3_IMAGELOC}pascal_2007.tgz'
    PASCAL_2012        = f'{S3_IMAGELOC}pascal_2012.tgz'

    # Audio classification datasets
    MACAQUES           = f'{GOOGLE}ml-animal-sounds-datasets/macaques.zip'
    ZEBRA_FINCH        = f'{GOOGLE}ml-animal-sounds-datasets/zebra_finch.zip'

    # Medical Imaging datasets
    #SKIN_LESION        = f'{S3_IMAGELOC}skin_lesion.tgz'
    SIIM_SMALL         = f'{S3_IMAGELOC}siim_small.tgz'
    TCGA_SMALL         = f'{S3_IMAGELOC}tcga_small.tgz'

    #Pretrained models
    OPENAI_TRANSFORMER = f'{S3_MODEL}transformer.tgz'
    WT103_FWD          = f'{S3_MODEL}wt103-fwd.tgz'
    WT103_BWD          = f'{S3_MODEL}wt103-bwd.tgz'

    def path(
        url:str='.', # File to download
        c_key:str='archive' # Key in `Config` where to save URL
    ) -> Path:
        "Local path where to download based on `c_key`"
        fname = url.split('/')[-1]
        local_path = URLs.LOCAL_PATH/('models' if c_key=='model' else 'data')/fname
        if local_path.exists(): return local_path
        return fastai_path(c_key)/fname
File:           ~/mambaforge/lib/python3.9/site-packages/fastai/data/external.py
Type:           type
Subclasses:     
path = untar_data(URLs.IMAGENETTE_160)

To access the image files, we can use get_image_files:

t = get_image_files(path)
t[0]
Path('/root/.fastai/data/imagenette2-160/train/n02102040/n02102040_5405.JPEG')

Or we could do the same thing using just Python's standard library, with glob:

glob?
Type:        module
String form: <module 'glob' from '/root/mambaforge/lib/python3.9/glob.py'>
File:        ~/mambaforge/lib/python3.9/glob.py
Docstring:   Filename globbing utility.
from glob import glob
files = L(glob(f'{path}/**/*.JPEG', recursive=True)).map(Path)
files[0]
Path('/root/.fastai/data/imagenette2-160/train/n02102040/n02102040_5405.JPEG')

If you look at the source for get_image_files, you'll see it uses Python's os.walk; this is a faster and more flexible function than glob, so be sure to try it out.

get_image_files??
Signature: get_image_files(path, recurse=True, folders=None)
Source:   
def get_image_files(path, recurse=True, folders=None):
    "Get image files in `path` recursively, only in `folders`, if specified."
    return get_files(path, extensions=image_extensions, recurse=recurse, folders=folders)
File:      ~/mambaforge/lib/python3.9/site-packages/fastai/data/transforms.py
Type:      function
get_files??
Signature:
get_files(
    path,
    extensions=None,
    recurse=True,
    folders=None,
    followlinks=True,
)
Source:   
def get_files(path, extensions=None, recurse=True, folders=None, followlinks=True):
    "Get all the files in `path` with optional `extensions`, optionally with `recurse`, only in `folders`, if specified."
    path = Path(path)
    folders=L(folders)
    extensions = setify(extensions)
    extensions = {e.lower() for e in extensions}
    if recurse:
        res = []
        for i,(p,d,f) in enumerate(os.walk(path, followlinks=followlinks)): # returns (dirpath, dirnames, filenames)
            if len(folders) !=0 and i==0: d[:] = [o for o in d if o in folders]
            else:                         d[:] = [o for o in d if not o.startswith('.')]
            if len(folders) !=0 and i==0 and '.' not in folders: continue
            res += _get_files(p, f, extensions)
    else:
        f = [o.name for o in os.scandir(path) if o.is_file()]
        res = _get_files(path, f, extensions)
    return L(res)
File:      ~/mambaforge/lib/python3.9/site-packages/fastai/data/transforms.py
Type:      function
os.walk?
Signature: os.walk(top, topdown=True, onerror=None, followlinks=False)
Docstring:
Directory tree generator.

For each directory in the directory tree rooted at top (including top
itself, but excluding '.' and '..'), yields a 3-tuple

    dirpath, dirnames, filenames

dirpath is a string, the path to the directory.  dirnames is a list of
the names of the subdirectories in dirpath (excluding '.' and '..').
filenames is a list of the names of the non-directory files in dirpath.
Note that the names in the lists are just names, with no path components.
To get a full path (which begins with top) to a file or directory in
dirpath, do os.path.join(dirpath, name).

If optional arg 'topdown' is true or not specified, the triple for a
directory is generated before the triples for any of its subdirectories
(directories are generated top down).  If topdown is false, the triple
for a directory is generated after the triples for all of its
subdirectories (directories are generated bottom up).

When topdown is true, the caller can modify the dirnames list in-place
(e.g., via del or slice assignment), and walk will only recurse into the
subdirectories whose names remain in dirnames; this can be used to prune the
search, or to impose a specific order of visiting.  Modifying dirnames when
topdown is false has no effect on the behavior of os.walk(), since the
directories in dirnames have already been generated by the time dirnames
itself is generated. No matter the value of topdown, the list of
subdirectories is retrieved before the tuples for the directory and its
subdirectories are generated.

By default errors from the os.scandir() call are ignored.  If
optional arg 'onerror' is specified, it should be a function; it
will be called with one argument, an OSError instance.  It can
report the error to continue with the walk, or raise the exception
to abort the walk.  Note that the filename is available as the
filename attribute of the exception object.

By default, os.walk does not follow symbolic links to subdirectories on
systems that support them.  In order to get this functionality, set the
optional argument 'followlinks' to true.

Caution:  if you pass a relative pathname for top, don't change the
current working directory between resumptions of walk.  walk never
changes the current directory, and assumes that the client doesn't
either.

Example:

import os
from os.path import join, getsize
for root, dirs, files in os.walk('python/Lib/email'):
    print(root, "consumes", end="")
    print(sum(getsize(join(root, name)) for name in files), end="")
    print("bytes in", len(files), "non-directory files")
    if 'CVS' in dirs:
        dirs.remove('CVS')  # don't visit CVS directories
File:      ~/mambaforge/lib/python3.9/os.py
Type:      function
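The in-place pruning described in the docstring is what get_files relies on when it assigns to d[:]. Here is a minimal standard-library sketch of the same idea. The helper name list_files and the file names are hypothetical, and this is a simplification rather than fastai's implementation:

```python
import os
import tempfile
from pathlib import Path

def list_files(top, extensions):
    "Collect files under `top` whose lowercased suffix is in `extensions`, skipping hidden directories."
    found = []
    for dirpath, dirnames, filenames in os.walk(top):
        # Slice assignment edits the list that os.walk holds, so pruned directories are never visited.
        dirnames[:] = [d for d in dirnames if not d.startswith('.')]
        found += [Path(dirpath)/f for f in filenames if Path(f).suffix.lower() in extensions]
    return sorted(found)

# Build a tiny throwaway tree (hypothetical file names) to exercise the function.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for rel in ['train/a.JPEG', 'val/b.jpeg', '.cache/c.JPEG', 'train/notes.txt']:
        (root/rel).parent.mkdir(parents=True, exist_ok=True)
        (root/rel).touch()
    names = [p.relative_to(root).as_posix() for p in list_files(root, {'.jpeg'})]

print(names)
```

Rebinding the name instead (dirnames = [...]) would not prune anything, because os.walk keeps its own reference to the original list.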
Be careful when materializing the result: os.walk is a lazy generator, and wrapping a walk of the filesystem root in L forces it to visit every directory on the machine. The cell below had to be interrupted by hand:

L(os.walk('/'))
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Input In [6], in <cell line: 1>()
----> 1 L(os.walk('/'))

    [... frames in fastcore/foundation.py, fastcore/basics.py, and os.py (_walk) omitted ...]

KeyboardInterrupt: 

We can open an image with the Python Imaging Library's Image class:

Image.open?
Signature: Image.open(fp, mode='r', formats=None)
Docstring:
Opens and identifies the given image file.

This is a lazy operation; this function identifies the file, but
the file remains open and the actual image data is not read from
the file until you try to process the data (or call the
:py:meth:`~PIL.Image.Image.load` method).  See
:py:func:`~PIL.Image.new`. See :ref:`file-handling`.

:param fp: A filename (string), pathlib.Path object or a file object.
   The file object must implement ``file.read``,
   ``file.seek``, and ``file.tell`` methods,
   and be opened in binary mode.
:param mode: The mode.  If given, this argument must be "r".
:param formats: A list or tuple of formats to attempt to load the file in.
   This can be used to restrict the set of formats checked.
   Pass ``None`` to try all supported formats. You can print the set of
   available formats by running ``python3 -m PIL`` or using
   the :py:func:`PIL.features.pilinfo` function.
:returns: An :py:class:`~PIL.Image.Image` object.
:exception FileNotFoundError: If the file cannot be found.
:exception PIL.UnidentifiedImageError: If the image cannot be opened and
   identified.
:exception ValueError: If the ``mode`` is not "r", or if a ``StringIO``
   instance is used for ``fp``.
:exception TypeError: If ``formats`` is not ``None``, a list or a tuple.
File:      ~/mambaforge/lib/python3.9/site-packages/PIL/Image.py
Type:      function
im = Image.open(files[0])
im
im_t = tensor(im)
im_t.shape
torch.Size([240, 160, 3])

That's going to be the basis of our independent variable. For our dependent variable, we can use Path.parent from pathlib. First we'll need our vocab:

files[0]
Path('/root/.fastai/data/imagenette2-160/train/n02102040/n02102040_5405.JPEG')
files[0].parent
Path('/root/.fastai/data/imagenette2-160/train/n02102040')
files[0].parent.name
'n02102040'
files.map(lambda x: x.parent.name).unique()
(#10) ['n02102040','n03445777','n03394916','n03425413','n03000684','n01440764','n03888257','n03417042','n02979186','n03028079']
files.map(lambda x: x.parent.name).unique().val2idx??
Signature: val2idx(x)
Source:   
def val2idx(x):
    "Dict from value to index"
    return {v:k for k,v in enumerate(x)}
File:      ~/mambaforge/lib/python3.9/site-packages/fastcore/basics.py
Type:      function
{v:i for i, v in enumerate(files.map(lambda x: x.parent.name).unique())}
{'n02102040': 0,
 'n03445777': 1,
 'n03394916': 2,
 'n03425413': 3,
 'n03000684': 4,
 'n01440764': 5,
 'n03888257': 6,
 'n03417042': 7,
 'n02979186': 8,
 'n03028079': 9}
lbls = files.map(Self.parent.name()).unique(); lbls
(#10) ['n02102040','n03445777','n03394916','n03425413','n03000684','n01440764','n03888257','n03417042','n02979186','n03028079']

...and the reverse mapping, thanks to L.val2idx:

v2i = lbls.val2idx(); v2i
{'n02102040': 0,
 'n03445777': 1,
 'n03394916': 2,
 'n03425413': 3,
 'n03000684': 4,
 'n01440764': 5,
 'n03888257': 6,
 'n03417042': 7,
 'n02979186': 8,
 'n03028079': 9}
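For readers who want to see the vocab step without fastcore, here is a plain-Python sketch. The file paths are hypothetical, and dict.fromkeys is used as an order-preserving stand-in for unique (an assumption based on the output order shown above):

```python
from pathlib import PurePosixPath

# Hypothetical file paths following the train/<label>/<file> layout shown above.
paths = [PurePosixPath(p) for p in [
    'train/n02102040/img_1.JPEG',
    'train/n03445777/img_2.JPEG',
    'train/n02102040/img_3.JPEG',
]]

# dict.fromkeys keeps the first occurrence of each label, in order.
labels = list(dict.fromkeys(p.parent.name for p in paths))

# Value-to-index mapping, built with the same comprehension as val2idx.
label_to_idx = {v: i for i, v in enumerate(labels)}

# Encode one path to an integer, then decode it back through the vocab.
idx = label_to_idx[paths[1].parent.name]
print(labels, label_to_idx, idx, labels[idx])
```

The list gives index-to-label lookups and the dict gives label-to-index lookups, which together are all a classification target needs.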

That's all the pieces we need to put together our Dataset.

Dataset

A Dataset in PyTorch can be anything that supports indexing (__getitem__) and len:

class Dataset:
    def __init__(self, fns): self.fns=fns
    def __len__(self): return len(self.fns)
    def __getitem__(self, i):
        im = Image.open(self.fns[i]).resize((64,64)).convert('RGB')
        y = v2i[self.fns[i].parent.name]
        return tensor(im).float()/255, tensor(y)
dset = Dataset(get_image_files(path)[:3])
len(dset)
3
X, y = dset[0]
X.shape, y
(torch.Size([64, 64, 3]), tensor(0))

We need a list of training and validation filenames to pass to Dataset.__init__:

set(o.parent.parent.name for o in files)
{'train', 'val'}
train_filt = L(o.parent.parent.name=='train' for o in files)
train,valid = files[train_filt],files[~train_filt]
len(train),len(valid)
(9469, 3925)
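The boolean-mask split above can be sketched in plain Python as follows. The paths are hypothetical, and list comprehensions stand in for L's mask indexing and its ~ negation:

```python
from pathlib import PurePosixPath

# Hypothetical paths in the <split>/<label>/<file> layout used above.
files = [PurePosixPath(p) for p in [
    'train/n02102040/a.JPEG',
    'val/n02102040/b.JPEG',
    'train/n03445777/c.JPEG',
]]

# One boolean per file: is its grandparent directory named 'train'?
is_train = [f.parent.parent.name == 'train' for f in files]

# Select with the mask, and with its negation.
train = [f for f, keep in zip(files, is_train) if keep]
valid = [f for f, keep in zip(files, is_train) if not keep]

print(len(train), len(valid))
```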

Now we can try it out:

train_ds,valid_ds = Dataset(train),Dataset(valid)
x,y = train_ds[0]
x.shape,y
(torch.Size([64, 64, 3]), tensor(0))
show_image(x, title=lbls[y]);

As you see, our dataset is returning the independent and dependent variables as a tuple, which is just what we need. We'll need to be able to collate these into a mini-batch. Generally this is done with torch.stack, which is what we'll use here:

tmp = [(o, i) for i, o in enumerate("A B C D E F".split())]
a, b = zip(*tmp)
a, b
(('A', 'B', 'C', 'D', 'E', 'F'), (0, 1, 2, 3, 4, 5))
def collate(idxs, ds): 
    xb,yb = zip(*[ds[i] for i in idxs])
    return torch.stack(xb),torch.stack(yb)
class Dataset0:
    def __init__(self, l): self.l=l
    def __len__(self): return len(self.l)
    def __getitem__(self, i):
        return torch.tensor(ord(self.l[i][0])), torch.tensor(self.l[i][1])
    
collate([0,1,2],
        Dataset0([(o, i) for i, o in enumerate("a b c d e f".split())]))
(tensor([97, 98, 99]), tensor([0, 1, 2]))

Here's a mini-batch with two items, for testing our collate:

x,y = collate([1,2], train_ds)
x.shape,y
(torch.Size([2, 64, 64, 3]), tensor([0, 0]))

Now that we have a dataset and a collation function, we're ready to create DataLoader. We'll add two more things here: an optional shuffle for the training set, and a ProcessPoolExecutor to do our preprocessing in parallel. A parallel data loader is very important, because opening and decoding a JPEG image is a slow process. One CPU core is not enough to decode images fast enough to keep a modern GPU busy. Here's our DataLoader class:

L.range(9).shuffle()
(#9) [3,6,7,4,8,2,5,0,1]
class DataLoader:
    def __init__(self, ds, bs=128, shuffle=False, n_workers=1):
        self.ds,self.bs,self.shuffle,self.n_workers = ds,bs,shuffle,n_workers

    def __len__(self): return (len(self.ds)-1)//self.bs+1

    def __iter__(self):
        idxs = L.range(self.ds)
        if self.shuffle: idxs = idxs.shuffle()
        chunks = [idxs[n:n+self.bs] for n in range(0, len(self.ds), self.bs)]
        with ProcessPoolExecutor(self.n_workers) as ex:
            yield from ex.map(collate, chunks, ds=self.ds)
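The call ex.map(collate, chunks, ds=self.ds) passes ds as a keyword argument, which appears to rely on fastcore's executor forwarding extra keywords to the function. The standard library's Executor.map does not accept them, so functools.partial is the usual substitute. The sketch below shows the same chunking and mapping logic using only the standard library. Two swaps are assumptions made for portability: it uses ThreadPoolExecutor instead of a process pool, and a toy list instead of an image dataset. The names batches and n_batches are hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import random

def collate(idxs, ds):
    "Fetch the items at `idxs` and transpose them into (inputs, targets)."
    xs, ys = zip(*[ds[i] for i in idxs])
    return list(xs), list(ys)

def batches(ds, bs, shuffle=False, n_workers=2, seed=None):
    "Yield collated mini-batches of at most `bs` items from `ds`."
    idxs = list(range(len(ds)))
    if shuffle: random.Random(seed).shuffle(idxs)
    chunks = [idxs[n:n+bs] for n in range(0, len(ds), bs)]
    with ThreadPoolExecutor(n_workers) as ex:
        # Executor.map returns results in the order of `chunks`, whichever worker finishes first.
        yield from ex.map(partial(collate, ds=ds), chunks)

def n_batches(n_items, bs):
    "Ceiling division, as in DataLoader.__len__."
    return (n_items - 1)//bs + 1

# Toy dataset (hypothetical): item i is the pair (i*10, i).
toy = [(i*10, i) for i in range(7)]
out = list(batches(toy, bs=3))
print(out)
print(n_batches(len(toy), 3), n_batches(9469, 128))
```

The last batch is simply shorter than the others, which is why the length is a ceiling division rather than a floor division.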

Let's try it out with our training and validation datasets:

defaults.cpus
8
n_workers = min(16, defaults.cpus)
%time train_dl = DataLoader(train_ds, bs=128, shuffle=True, n_workers=n_workers)
valid_dl = DataLoader(valid_ds, bs=256, shuffle=False, n_workers=n_workers)
xb,yb = first(train_dl)
xb.shape,yb.shape,len(train_dl)
CPU times: user 7 µs, sys: 4 µs, total: 11 µs
Wall time: 15.7 µs
(torch.Size([128, 64, 64, 3]), torch.Size([128]), 74)

This data loader is not much slower than PyTorch's, but it's far simpler. So if you're debugging a complex data loading process, don't be afraid to try doing things manually to help you see exactly what's going on.

For normalization, we'll need image statistics. Generally it's fine to calculate these on a single training mini-batch, since precision isn't needed here:

stats = [xb.mean((0,1,2)),xb.std((0,1,2))]
stats
[tensor([0.4595, 0.4560, 0.4263]), tensor([0.2685, 0.2648, 0.2886])]
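To make the reduction explicit, here is a small sketch of per-channel statistics in plain Python. The pixel values are hypothetical, and a flat list of RGB tuples stands in for the batch, height, and width axes that are reduced away above. statistics.stdev uses the n-1 divisor, as torch's std does by default.

```python
from statistics import mean, stdev

# Hypothetical RGB pixels with values in [0, 1]; each tuple is (r, g, b).
pixels = [(0.2, 0.4, 0.6), (0.4, 0.4, 0.9), (0.6, 0.1, 0.3), (0.8, 0.7, 0.6)]

# zip(*pixels) regroups the data by channel, like reducing over every axis except the last.
channels = list(zip(*pixels))
means = [mean(c) for c in channels]
# stdev is the sample standard deviation (n-1 divisor).
stds = [stdev(c) for c in channels]

# Normalize each channel with its own statistics.
normed = [tuple((v - m)/s for v, m, s in zip(px, means, stds)) for px in pixels]
normed_channels = list(zip(*normed))

print(means, stds)
print([round(mean(c), 6) for c in normed_channels])
```

Data normalized with its own statistics has mean 0 and standard deviation 1 per channel. A different batch normalized with these statistics will only be approximately so.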

Our Normalize class just needs to store these stats and apply them (to see why the to_device is needed, try commenting it out, and see what happens later in this notebook):

class Normalize:
    def __init__(self, stats): self.stats=stats
    def __call__(self, x):
        # Move the stats to the input's device on first use, so x and stats can be combined
        if x.device != self.stats[0].device:
            self.stats = to_device(self.stats, x.device)
        return (x-self.stats[0])/self.stats[1]

We always like to test everything we build in a notebook, as soon as we build it:

torch.arange(2*3*4*5).reshape(2,3,4,5).permute((0,3,1,2)).shape # reorder axes
torch.Size([2, 5, 3, 4])
x.shape, x.permute(0,3,1,2).shape # -> (N, C, H, W)
(torch.Size([2, 64, 64, 3]), torch.Size([2, 3, 64, 64]))
norm = Normalize(stats)
def tfm_x(x): return norm(x).permute((0,3,1,2))
t = tfm_x(x)
t.mean((0,2,3)),t.std((0,2,3))
(tensor([-0.1056, -0.1291,  0.1143]), tensor([0.8184, 0.8226, 0.7869]))

Here tfm_x isn't just applying Normalize, but is also permuting the axis order from NHWC to NCHW (see <> if you need a reminder of what these acronyms refer to).

PIL uses HWC axis order, which we can't use with PyTorch, hence the need for this permute.
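What the permute does to each element can be spelled out with nested lists. This sketch handles a single image rather than a batch (an assumption made to keep it short), and the helper name hwc_to_chw is hypothetical:

```python
def hwc_to_chw(img):
    "Rearrange a nested-list image from [H][W][C] to [C][H][W]."
    h, w, c = len(img), len(img[0]), len(img[0][0])
    return [[[img[i][j][k] for j in range(w)] for i in range(h)] for k in range(c)]

# A 2x3 image with 3 channels; each value encodes its own (row, col, channel) position.
hwc = [[[100*i + 10*j + k for k in range(3)] for j in range(3)] for i in range(2)]
chw = hwc_to_chw(hwc)

# Shape is now (channels, height, width).
shape = (len(chw), len(chw[0]), len(chw[0][0]))
print(shape)

# The element at (row 1, col 0, channel 2) is now found at (channel 2, row 1, col 0).
print(chw[2][1][0], hwc[1][0][2])
```

No values change; only the order in which the indices are written does.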

That's all we need for the data for our model. So now we need the model itself!

Module and Parameter

To create a model, we'll need Module. To create Module, we'll need Parameter, so let's start there. Recall that in <> we said that the Parameter class "doesn't actually add any functionality (other than automatically calling requires_grad_ for us). It's only used as a 'marker' to show what to include in parameters." Here's a definition that does exactly that:

class Parameter(Tensor):
    def __new__(self, x): return Tensor._make_subclass(Parameter, x, True)
    def __init__(self, *args, **kwargs): self.requires_grad_()

The implementation here is a bit awkward: we have to define the special __new__ Python method and use the internal PyTorch method _make_subclass because, as of the time of writing, PyTorch doesn't otherwise work correctly with this kind of subclassing or provide an officially supported API to do this. This may have been fixed by the time you read this, so look on the book's website to see if there are updated details.

Our Parameter now behaves just like a tensor, as we wanted:

Parameter(tensor(3.))
tensor(3., requires_grad=True)

Now that we have this, we can define Module:

class Module:
    def __init__(self):
        self.hook,self.params,self.children,self._training = None,[],[],False
        
    def register_parameters(self, *ps): self.params += ps
    def register_modules   (self, *ms): self.children += ms
        
    @property
    def training(self): return self._training
    @training.setter
    def training(self,v):
        self._training = v
        for m in self.children: m.training=v
            
    def parameters(self):
        return self.params + sum([m.parameters() for m in self.children], [])

    def __setattr__(self,k,v):
        super().__setattr__(k,v)
        if isinstance(v,Parameter): self.register_parameters(v)
        if isinstance(v,Module):    self.register_modules(v)
        
    def __call__(self, *args, **kwargs):
        res = self.forward(*args, **kwargs)
        if self.hook is not None: self.hook(res, args)
        return res
    
    def cuda(self):
        for p in self.parameters(): p.data = p.data.cuda()

The key functionality is in the definition of parameters:

self.params + sum([m.parameters() for m in self.children], [])

This means that we can ask any Module for its parameters, and it will return them, including those of all its child modules (recursively). But how does it know what its parameters are? It's thanks to implementing Python's special __setattr__ method, which is called for us any time an attribute is set on an instance of the class. Our implementation includes this line:

if isinstance(v,Parameter): self.register_parameters(v)

As you see, this is where we use our new Parameter class as a "marker"—anything of this class is added to our params.
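The registration mechanism does not depend on tensors at all. The following sketch reproduces it with a plain marker class. The names Param, MiniModule, Leaf, and Tree are hypothetical stand-ins, not part of the chapter's code:

```python
class Param:
    "Marker type: holds a value and does nothing else."
    def __init__(self, value): self.value = value

class MiniModule:
    def __init__(self):
        self.params, self.children = [], []

    def __setattr__(self, k, v):
        super().__setattr__(k, v)
        # Registration happens as a side effect of ordinary attribute assignment.
        if isinstance(v, Param):      self.params.append(v)
        if isinstance(v, MiniModule): self.children.append(v)

    def parameters(self):
        # Own parameters first, then each child's, recursively.
        return self.params + sum([m.parameters() for m in self.children], [])

class Leaf(MiniModule):
    def __init__(self):
        super().__init__()
        self.w, self.b = Param(1.0), Param(0.0)

class Tree(MiniModule):
    def __init__(self):
        super().__init__()
        self.scale = Param(2.0)
        self.left, self.right = Leaf(), Leaf()
        self.name = 'tree'   # not a Param or MiniModule, so it is not registered

tree = Tree()
print(len(tree.parameters()))
```

Note that super().__init__() must run before any parameter is assigned, since it creates the lists that __setattr__ appends to.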

Python's __call__ allows us to define what happens when our object is treated as a function; we just call forward (which doesn't exist here, so it'll need to be added by subclasses). After forward returns, we call a hook, if one is defined, passing it the result and the input arguments. Now you can see that PyTorch hooks aren't doing anything fancy at all: the module simply calls whatever hook has been registered.
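Here is a stripped-down sketch of just the __call__ and hook mechanics, so the order of events is easy to observe. The class names Hookable and Double are hypothetical:

```python
class Hookable:
    "Minimal stand-in for Module: only the __call__/hook mechanics."
    def __init__(self): self.hook = None

    def __call__(self, *args, **kwargs):
        res = self.forward(*args, **kwargs)
        # The hook runs after forward and receives the output and the positional inputs.
        if self.hook is not None: self.hook(res, args)
        return res

class Double(Hookable):
    def forward(self, x): return 2*x

seen = []
layer = Double()
layer.hook = lambda res, args: seen.append((args, res))

out = layer(21)
print(out, seen)
```

Because the hook only observes the result, the value returned to the caller is unchanged whether or not a hook is set.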

Other than these pieces of functionality, our Module also provides cuda and training attributes, which we'll use shortly.
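The training property is worth a closer look, because its setter is what pushes the flag down the whole tree. This is a minimal sketch with a hypothetical Node class that keeps only that behavior:

```python
class Node:
    def __init__(self, *children):
        self.children, self._training = list(children), False

    @property
    def training(self): return self._training

    @training.setter
    def training(self, v):
        self._training = v
        # Assigning to a child's property triggers the child's setter, which recurses further.
        for m in self.children: m.training = v

leaf_a, leaf_b = Node(), Node()
root = Node(Node(leaf_a), leaf_b)

root.training = True
print(leaf_a.training, leaf_b.training)
```

Setting the flag on the root is enough; there is no need to visit the nested modules by hand.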

Now we can create our first Module, which is ConvLayer:

class ConvLayer(Module):
    def __init__(self, ni, nf, stride=1, bias=True, act=True):
        super().__init__()
        self.w = Parameter(torch.zeros(nf,ni,3,3))
        self.b = Parameter(torch.zeros(nf)) if bias else None
        self.act,self.stride = act,stride
        init = nn.init.kaiming_normal_ if act else nn.init.xavier_normal_
        init(self.w)
    
    def forward(self, x):
        x = F.conv2d(x, self.w, self.b, stride=self.stride, padding=1)
        if self.act: x = F.relu(x)
        return x

We're not implementing F.conv2d from scratch, since you should have already done that (using unfold) in the questionnaire in <>. Instead, we're just creating a small class that wraps it up along with bias and weight initialization. Let's check that it works correctly with Module.parameters:

l = ConvLayer(3, 4)
l.parameters()[0].shape, l.parameters()[1].shape 
(torch.Size([4, 3, 3, 3]), torch.Size([4]))

And that we can call it (which will result in forward being called):

xbt = tfm_x(xb)
xb.shape, xbt.shape
(torch.Size([128, 64, 64, 3]), torch.Size([128, 3, 64, 64]))
r = l(xbt)
r.shape
torch.Size([128, 4, 64, 64])

In the same way, we can implement Linear:

class Linear(Module):
    def __init__(self, ni, nf):
        super().__init__()
        self.w = Parameter(torch.zeros(nf,ni))
        self.b = Parameter(torch.zeros(nf))
        nn.init.xavier_normal_(self.w)
    
    def forward(self, x): return x@self.w.t() + self.b

and test if it works:

l = Linear(4,2)
r = l(torch.ones(3,4))
r.shape
torch.Size([3, 2])

Let's also create a testing module to check that if we include multiple parameters as attributes, they are all correctly registered:

class T(Module):
    def __init__(self):
        super().__init__()
        self.c,self.l = ConvLayer(3,4),Linear(4,2)

Since we have a conv layer and a linear layer, each of which has weights and biases, we'd expect four parameters in total:

t = T()
t.children, len(t.children[0].parameters()), len(t.children[1].parameters())
([<__main__.ConvLayer at 0x7f9518d16520>, <__main__.Linear at 0x7f95192ae3a0>],
 2,
 2)
len(t.parameters())
4

We should also find that calling cuda on this class puts all these parameters on the GPU:

t.l.w.device
device(type='cpu')
t.cuda()
t.l.w.device
device(type='cuda', index=0)

We can now use those pieces to create a CNN.

Simple CNN

As we've seen, a Sequential class makes many architectures easier to implement, so let's make one:

class Sequential(Module):
    def __init__(self, *layers):
        super().__init__()
        self.layers = layers
        self.register_modules(*layers)

    def forward(self, x):
        for l in self.layers: x = l(x)
        return x

The forward method here just calls each layer in turn. Note that we have to use the register_modules method we defined in Module, since otherwise the contents of layers won't appear in parameters.

Important: All The Code is Here: Remember that we’re not using any PyTorch functionality for modules here; we’re defining everything ourselves. So if you’re not sure what register_modules does, or why it’s needed, have another look at our code for Module to see what we wrote!

We can create a simplified AdaptivePool that only handles pooling to a 1×1 output, and flattens it as well, by just using mean:

class AdaptivePool(Module):
    def forward(self, x): return x.mean((2,3))

That's enough for us to create a CNN!

def simple_cnn():
    return Sequential(
        ConvLayer(3 ,16 ,stride=2), #32
        ConvLayer(16,32 ,stride=2), #16
        ConvLayer(32,64 ,stride=2), # 8
        ConvLayer(64,128,stride=2), # 4
        AdaptivePool(),
        Linear(128, 10)
    )
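The size comments in simple_cnn can be checked with the usual convolution output-size formula, floor((n + 2*padding - ks) / stride) + 1. This is a sketch only: `conv_out_size` is a hypothetical helper, and the 64-pixel starting size is taken from the shape of `xbt` shown earlier:

```python
def conv_out_size(n, ks=3, stride=1, padding=1):
    "Spatial output size of a convolution over an input of size `n`."
    return (n + 2 * padding - ks) // stride + 1

size, sizes = 64, []
for _ in range(4):                        # the four stride-2 ConvLayers
    size = conv_out_size(size, stride=2)
    sizes.append(size)
print(sizes)
```

With a 3×3 kernel and padding of 1, each stride-2 layer halves the grid, which is why the final feature map is 4×4 before pooling.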

Let's see if our parameters are all being registered correctly:

m = simple_cnn()
len(m.parameters())
10
[len(o.parameters()) for o in m.children]
[2, 2, 2, 2, 0, 2]
sum([len(o.parameters()) for o in m.children])
10

Now we can try adding a hook. Note that we've only left room for one hook in Module; you could make it a list, or use something like Pipeline to run a few as a single function:

def print_stats(outp, inp): print (outp.mean().item(),outp.std().item())
for i in range(4): m.layers[i].hook = print_stats

r = m(xbt)
r.shape
0.5291942358016968 0.8697973489761353
0.4359581172466278 0.825819730758667
0.4345751404762268 0.7494376301765442
0.46103182435035706 0.7244757413864136
torch.Size([128, 10])
m.layers
(<__main__.ConvLayer at 0x7fec32275220>,
 <__main__.ConvLayer at 0x7fec32275460>,
 <__main__.ConvLayer at 0x7fec32275d00>,
 <__main__.ConvLayer at 0x7fec32275e20>,
 <__main__.AdaptivePool at 0x7fec32268d30>,
 <__main__.Linear at 0x7fec322754f0>)
m.children
[<__main__.ConvLayer at 0x7fec32275220>,
 <__main__.ConvLayer at 0x7fec32275460>,
 <__main__.ConvLayer at 0x7fec32275d00>,
 <__main__.ConvLayer at 0x7fec32275e20>,
 <__main__.AdaptivePool at 0x7fec32268d30>,
 <__main__.Linear at 0x7fec322754f0>]
m
<__main__.Sequential at 0x7fec32275af0>
def print_stats(outp, inp): print (outp.mean().item(),outp.std().item())
for o in m.layers: o.hook = print_stats

r = m(xbt)
r.shape
0.5291942358016968 0.8697973489761353
0.4359581172466278 0.825819730758667
0.4345751404762268 0.7494376301765442
0.46103182435035706 0.7244757413864136
0.46103185415267944 0.5164786577224731
0.4029931128025055 0.9249605536460876
torch.Size([128, 10])

We have data and model. Now we need a loss function.

Loss

We've already seen how to define "negative log likelihood":

def nll(input, target): return -input[range(target.shape[0]), target].mean()

To see what the indexing in nll does, let's try it on a small random batch of made-up activations and targets:

batch_size = 5
n_classes = 3
input = torch.randn(batch_size, n_classes).relu()
input
tensor([[0.0000, 0.0000, 0.0000],
        [0.8550, 0.0000, 0.5435],
        [0.0000, 2.5415, 0.0000],
        [0.0000, 1.0593, 0.0000],
        [0.2269, 0.0000, 0.0000]])
target = torch.randint(0, n_classes, (batch_size,))
target
tensor([1, 0, 0, 2, 2])
-input[range(batch_size), target]
tensor([-0.0000, -0.8550, -0.0000, -0.0000, -0.0000])
-input[range(batch_size), target].mean()
tensor(-0.1710)
nll(input, target)
tensor(-0.1710)

Well actually, there's no log here, since we're using the same definition as PyTorch. That means we need to put the log together with softmax:

def log_softmax(x): return (x.exp()/(x.exp().sum(-1,keepdim=True))).log()

sm = log_softmax(r); sm[0][0]
tensor(-1.8825, grad_fn=<AliasBackward0>)

Combining these gives us our cross-entropy loss:

loss = nll(sm, yb)
loss
tensor(2.4390, grad_fn=<AliasBackward0>)

Note that the formula:

$$\log \left ( \frac{a}{b} \right ) = \log(a) - \log(b)$$

gives a simplification when we compute the log softmax, which was previously defined as (x.exp()/(x.exp().sum(-1,keepdim=True))).log():

def log_softmax(x): return x - x.exp().sum(-1,keepdim=True).log()
sm = log_softmax(r); sm[0][0]
tensor(-0.0448, grad_fn=<AliasBackward0>)
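Written out for a single class $i$, the simplification uses nothing beyond the log-of-a-quotient identity:

$$\log \left ( \frac{e^{x_{i}}}{\sum_{j=1}^{n} e^{x_{j}}} \right ) = \log \left ( e^{x_{i}} \right ) - \log \left ( \sum_{j=1}^{n} e^{x_{j}} \right ) = x_{i} - \log \left ( \sum_{j=1}^{n} e^{x_{j}} \right )$$

The second term is what x.exp().sum(-1,keepdim=True).log() computes for each row.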

Then, there is a more stable way to compute the log of the sum of exponentials, called the LogSumExp trick. The idea is to use the following formula:

$$\log \left ( \sum_{j=1}^{n} e^{x_{j}} \right ) = \log \left ( e^{a} \sum_{j=1}^{n} e^{x_{j}-a} \right ) = a + \log \left ( \sum_{j=1}^{n} e^{x_{j}-a} \right )$$

where $a$ is the maximum of $x_{j}$.

Here's the same thing in code:

x = torch.rand(5)
a = x.max()
x.exp().sum().log() == a + (x-a).exp().sum().log()
tensor(True)
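With small values like these, both sides agree, so the benefit is not visible. The sketch below shows the difference with large inputs. It is written in plain Python: `naive_logsumexp` and `stable_logsumexp` are hypothetical helpers, and Python floats raise OverflowError where a tensor would produce inf:

```python
import math

def naive_logsumexp(xs):
    return math.log(sum(math.exp(x) for x in xs))

def stable_logsumexp(xs):
    a = max(xs)                        # after subtracting, the largest exponent is 0
    return a + math.log(sum(math.exp(x - a) for x in xs))

small = [0.1, 0.5, 0.9]
print(math.isclose(naive_logsumexp(small), stable_logsumexp(small)))

large = [1000.0, 1000.0]
try:
    naive_logsumexp(large)
    overflowed = False
except OverflowError:                  # math.exp(1000) exceeds the float range
    overflowed = True
print(overflowed, stable_logsumexp(large))
```

After subtracting the maximum, every exponent is at most zero, so no term can overflow, and the largest term is exactly 1, so the sum never underflows to zero.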

We'll put that into a function:

def logsumexp(x):
    m = x.max(-1)[0]
    return m + (x-m[:,None]).exp().sum(-1).log()

logsumexp(r)[0]
tensor(0.9535, grad_fn=<AliasBackward0>)

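PyTorch tensors already provide this computation as a built-in logsumexp method, which takes the dimension to reduce over and a keepdim flag, so we can use it for our log_softmax function: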

def log_softmax(x): return x - x.logsumexp(-1,keepdim=True)

Which gives the same result as before:

sm = log_softmax(r); sm[0][0]
tensor(-0.0448, grad_fn=<AliasBackward0>)

We can use these to create cross_entropy:

def cross_entropy(preds, yb): return nll(log_softmax(preds), yb).mean()

Let's now combine all those pieces together to create a Learner.

Learner

We have data, a model, and a loss function; we only need one more thing before we can fit a model, and that's an optimizer! Here's SGD:

class SGD:
    def __init__(self, params, lr, wd=0.): store_attr()
    def step(self):
        for p in self.params:
            p.data -= (p.grad.data + p.data*self.wd) * self.lr
            p.grad.data.zero_()
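To make the arithmetic of step concrete, here is the same update rule applied to plain Python floats. `sgd_step` is a hypothetical helper and the numbers are made up:

```python
def sgd_step(params, grads, lr, wd=0.):
    "Return updated copies of `params` after one SGD step with weight decay."
    return [p - (g + p * wd) * lr for p, g in zip(params, grads)]

params, grads = [1.0, -2.0], [0.5, 0.0]
plain   = sgd_step(params, grads, lr=0.1)            # gradient only
decayed = sgd_step(params, grads, lr=0.1, wd=0.01)   # weight decay also shrinks weights
print(plain, decayed)
```

The second parameter has a zero gradient, so without weight decay it does not move. With weight decay it is pulled slightly toward zero.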

As we've seen in this book, life is easier with a Learner. The Learner class needs to know our training and validation sets, which means we need DataLoaders to store them. We don't need any other functionality, just a place to store them and access them:

class DataLoaders:
    def __init__(self, *dls): self.train,self.valid = dls

dls = DataLoaders(train_dl,valid_dl)

Now we're ready to create our Learner class:

store_attr??
Signature:
store_attr(
    names=None,
    self=None,
    but='',
    cast=False,
    store_args=None,
    **attrs,
)
Source:   
def store_attr(names=None, self=None, but='', cast=False, store_args=None, **attrs):
    "Store params named in comma-separated `names` from calling context into attrs in `self`"
    fr = sys._getframe(1)
    args = argnames(fr, True)
    if self: args = ('self', *args)
    else: self = fr.f_locals[args[0]]
    if store_args is None: store_args = not hasattr(self,'__slots__')
    if store_args and not hasattr(self, '__stored_args__'): self.__stored_args__ = {}
    anno = annotations(self) if cast else {}
    if names and isinstance(names,str): names = re.split(', *', names)
    ns = names if names is not None else getattr(self, '__slots__', args[1:])
    added = {n:fr.f_locals[n] for n in ns}
    attrs = {**attrs, **added}
    if isinstance(but,str): but = re.split(', *', but)
    attrs = {k:v for k,v in attrs.items() if k not in but}
    return _store_attr(self, anno, **attrs)
File:      ~/mambaforge/lib/python3.9/site-packages/fastcore/basics.py
Type:      function
class Learner:
    def __init__(self, model, dls, loss_func, lr, cbs, opt_func=SGD):
        store_attr()
        for cb in cbs: cb.learner = self

    def one_batch(self):
        self('before_batch')
        xb,yb = self.batch
        self.preds = self.model(xb)
        self.loss = self.loss_func(self.preds, yb)
        if self.model.training:
            self.loss.backward()
            self.opt.step()
        self('after_batch')

    def one_epoch(self, train):
        self.model.training = train
        self('before_epoch')
        dl = self.dls.train if train else self.dls.valid
        #def progress_bar(dl, leave): return dl
        for self.num,self.batch in enumerate(progress_bar(dl, leave=False)):
            self.one_batch()
        self('after_epoch')
    
    def fit(self, n_epochs):
        self('before_fit')
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        self.n_epochs = n_epochs
        try:
            for self.epoch in range(n_epochs):
                self.one_epoch(True)
                self.one_epoch(False)
        except CancelFitException: pass
        self('after_fit')
        
    def __call__(self,name):
        for cb in self.cbs: getattr(cb,name,noop)()

This is the largest class we've created in the book, but each method is quite small, so by looking at each in turn you should be able to follow what's going on.

The main method we'll be calling is fit. This loops with:

for self.epoch in range(n_epochs)

and at each epoch calls self.one_epoch for each of train=True and then train=False. Then self.one_epoch calls self.one_batch for each batch in dls.train or dls.valid, as appropriate (after wrapping the DataLoader in fastprogress.progress_bar). Finally, self.one_batch follows the usual set of steps to fit one mini-batch that we've seen throughout this book.

Before and after each step, Learner calls self with the name of an event, which runs __call__ (standard Python functionality for making an object callable). __call__ uses getattr(cb,name,noop) on each callback in self.cbs; getattr is a Python built-in function that returns the attribute (a method, in this case) with the requested name, or the fallback noop (a function that does nothing) if the callback doesn't define it. So, for instance, self('before_fit') will call cb.before_fit() for each callback where that method is defined.
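The dispatch mechanism can be tried in isolation. In this sketch, `MiniLearner` and `RecordCB` are hypothetical classes, and `noop` is a local stand-in for the do-nothing function that Learner imports:

```python
def noop(*args, **kwargs): return None    # stand-in: a function that does nothing

class RecordCB:
    "Hypothetical callback that defines only two of the possible events."
    def __init__(self): self.events = []
    def before_fit(self): self.events.append('before_fit')
    def after_fit(self):  self.events.append('after_fit')

class MiniLearner:
    "Hypothetical, stripped-down learner containing only the dispatch logic."
    def __init__(self, cbs): self.cbs = cbs
    def __call__(self, name):
        # Look the method up by name; fall back to `noop` if it is missing.
        for cb in self.cbs: getattr(cb, name, noop)()

cb = RecordCB()
mini = MiniLearner([cb])
for event in ['before_fit', 'before_epoch', 'after_epoch', 'after_fit']:
    mini(event)
print(cb.events)
```

Events the callback does not define are silently skipped, so a callback only has to implement the methods it cares about.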

As you can see, Learner is really just using our standard training loop, except that it's also calling callbacks at appropriate times. So let's define some callbacks!

Callbacks

In Learner.__init__ we have:

for cb in cbs: cb.learner = self

In other words, every callback knows what learner it is used in. This is critical, since otherwise a callback can't get information from the learner, or change things in the learner. Because getting information from the learner is so common, we make that easier by defining Callback as a subclass of GetAttr, with a default attribute of learner:

class Callback(GetAttr): _default='learner'

GetAttr is a class from fastcore (the foundation library that fastai is built on) that implements Python's standard __getattr__ and __dir__ methods for you, such that any time you try to access an attribute that doesn't exist, it passes the request along to whatever you have defined as _default.

For instance, we want to move all model parameters to the GPU automatically at the start of fit. We could do this by defining before_fit as self.learner.model.cuda(); however, because learner is the default attribute, and we have SetupLearnerCB inherit from Callback (which inherits from GetAttr), we can remove the .learner and just call self.model.cuda():

class SetupLearnerCB(Callback):
    def before_batch(self):
        xb,yb = to_device(self.batch)
        self.learner.batch = tfm_x(xb),yb

    def before_fit(self): self.model.cuda()

In SetupLearnerCB we also move each mini-batch to the GPU, by calling to_device(self.batch) (we could also have used the longer to_device(self.learner.batch)). Note however that in the line self.learner.batch = tfm_x(xb),yb we can't remove .learner, because here we're setting the attribute, not getting it.
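The difference between getting and setting can be seen with a simplified delegation class. `Delegating` and `FakeLearner` are hypothetical; this is a sketch of the idea, not fastcore's actual GetAttr implementation:

```python
class Delegating:
    "Hypothetical, simplified stand-in for GetAttr."
    _default = 'learner'
    def __getattr__(self, k):
        # Only called when normal lookup fails; forward to the default object.
        if k.startswith('_') or k == self._default: raise AttributeError(k)
        return getattr(getattr(self, self._default), k)

class FakeLearner:
    def __init__(self): self.batch = 'raw batch'

learner, cb = FakeLearner(), Delegating()
cb.learner = learner

read_via_cb = cb.batch             # read is forwarded to learner.batch
cb.batch = 'changed'               # write creates a new attribute on cb itself
after_cb_write = learner.batch     # the learner is untouched
cb.learner.batch = 'changed'       # writing through .learner updates the learner
after_learner_write = learner.batch
print(read_via_cb, after_cb_write, after_learner_write, sep=' | ')
```

Python only consults `__getattr__` when a read fails, and assignment never goes through it, which is why the write has to name the learner explicitly.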

Before we try our Learner out, let's create a callback to track and print progress. Otherwise we won't really know if it's working properly:

class TrackResults(Callback):
    def before_epoch(self): self.accs,self.losses,self.ns = [],[],[]
        
    def after_epoch(self):
        n = sum(self.ns)
        print(self.epoch, self.model.training,
              sum(self.losses).item()/n, sum(self.accs).item()/n)
        
    def after_batch(self):
        xb,yb = self.batch
        acc = (self.preds.argmax(dim=1)==yb).float().sum()
        self.accs.append(acc)
        n = len(xb)
        self.losses.append(self.loss*n)
        self.ns.append(n)
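TrackResults multiplies each batch loss by the batch size before summing, and divides by the total number of items at the end. The made-up numbers below sketch why this matters when batches have different sizes (for instance, a smaller final batch):

```python
losses = [2.0, 1.0]   # mean loss of each mini-batch (made-up numbers)
ns     = [128, 64]    # number of items in each mini-batch

naive    = sum(losses) / len(losses)                          # treats batches equally
weighted = sum(l * n for l, n in zip(losses, ns)) / sum(ns)   # treats items equally
print(naive, weighted)
```

The weighted version is the true per-item average over the epoch, which is what after_epoch prints.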

Now we're ready to use our Learner for the first time!

cbs = [SetupLearnerCB(),TrackResults()]
learn = Learner(simple_cnn(), dls, cross_entropy, lr=0.1, cbs=cbs)
learn.fit(1)

It's quite amazing to realize that we can implement all the key ideas from fastai's Learner in so little code! Let's now add some learning rate scheduling.

Scheduling the Learning Rate

If we're going to get good results, we'll want an LR finder and 1cycle training. These are both annealing callbacks—that is, they are gradually changing hyperparameters as we train. Here's LRFinder:

class LRFinder(Callback):
    def before_fit(self):
        self.losses,self.lrs = [],[]
        self.learner.lr = 1e-6
        
    def before_batch(self):
        if not self.model.training: return
        self.opt.lr *= 1.2

    def after_batch(self):
        if not self.model.training: return
        if self.opt.lr>10 or torch.isnan(self.loss): raise CancelFitException
        self.losses.append(self.loss.item())
        self.lrs.append(self.opt.lr)

This shows how we're using CancelFitException, which is itself an empty class, only used to signify the type of exception. You can see in Learner that this exception is caught. (You should add and test CancelBatchException, CancelEpochException, etc. yourself.) Let's try it out, by adding it to our list of callbacks:

lrfind = LRFinder()
learn = Learner(simple_cnn(), dls, cross_entropy, lr=0.1, cbs=cbs+[lrfind])
learn.fit(2)

And take a look at the results:

plt.plot(lrfind.lrs[:-2],lrfind.losses[:-2])
plt.xscale('log')
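To get a feel for how long the sweep runs, we can simulate the learning rate logic on its own. This sketch ignores the NaN check and validation batches; `lr_sweep` is a hypothetical helper, and the exception class is redefined locally so the example is self-contained:

```python
class CancelFitException(Exception): pass        # empty marker exception

def lr_sweep(start=1e-6, mult=1.2, stop=10):
    "Learning rates an LRFinder-style sweep would record before cancelling."
    lr, lrs = start, []
    try:
        while True:
            lr *= mult                               # as in before_batch
            if lr > stop: raise CancelFitException   # as in after_batch
            lrs.append(lr)
    except CancelFitException: pass                  # as in Learner.fit
    return lrs

lrs = lr_sweep()
print(len(lrs), lrs[0], lrs[-1])
```

Multiplying by 1.2 each batch covers seven orders of magnitude in fewer than 90 training batches, so the finder needs only a small number of epochs.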

Now we can define our OneCycle training callback:

class OneCycle(Callback):
    def __init__(self, base_lr): self.base_lr = base_lr
    def before_fit(self): self.lrs = []

    def before_batch(self):
        if not self.model.training: return
        n = len(self.dls.train)
        bn = self.epoch*n + self.num
        mn = self.n_epochs*n
        pct = bn/mn
        pct_start,div_start = 0.25,10
        if pct<pct_start:
            pct /= pct_start
            lr = (1-pct)*self.base_lr/div_start + pct*self.base_lr
        else:
            pct = (pct-pct_start)/(1-pct_start)
            lr = (1-pct)*self.base_lr
        self.opt.lr = lr
        self.lrs.append(lr)

We'll try an LR of 0.1:

onecyc = OneCycle(0.1)
learn = Learner(simple_cnn(), dls, cross_entropy, lr=0.1, cbs=cbs+[onecyc])

Let's fit for a while and see how it looks (we won't show all the output in the book—try it in the notebook to see the results):

learn.fit(8)

Finally, we'll check that the learning rate followed the schedule we defined (as you see, we're not using cosine annealing here):

plt.plot(onecyc.lrs);
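The schedule can also be checked without training, by pulling the arithmetic from before_batch into a standalone function. `one_cycle_lr` is a hypothetical helper that mirrors the callback's linear warm-up and linear decay:

```python
def one_cycle_lr(pct, base_lr, pct_start=0.25, div_start=10):
    "Learning rate at fraction `pct` (0 to 1) of training, as in OneCycle.before_batch."
    if pct < pct_start:
        pct /= pct_start                          # progress through warm-up
        return (1 - pct) * base_lr / div_start + pct * base_lr
    pct = (pct - pct_start) / (1 - pct_start)     # progress through decay
    return (1 - pct) * base_lr

for pct in [0.0, 0.125, 0.25, 0.625, 1.0]:
    print(pct, round(one_cycle_lr(pct, 0.1), 4))
```

The rate starts at base_lr/10, reaches base_lr a quarter of the way through training, and then falls linearly toward zero.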

Conclusion

We have explored how the key concepts of the fastai library are implemented by re-implementing them in this chapter. Since it's mostly full of code, you should definitely try to experiment with it by looking at the corresponding notebook on the book's website. Now that you know how it's built, as a next step be sure to check out the intermediate and advanced tutorials in the fastai documentation to learn how to customize every bit of the library.

Questionnaire

Tip: Experiments: For the questions here that ask you to explain what some function or class is, you should also complete your own code experiments.

  1. What is glob?
  2. How do you open an image with the Python imaging library?
  3. What does L.map do?
  4. What does Self do?
  5. What is L.val2idx?
  6. What methods do you need to implement to create your own Dataset?
  7. Why do we call convert when we open an image from Imagenette?
  8. What does ~ do? How is it useful for splitting training and validation sets?
  9. Does ~ work with the L or Tensor classes? What about NumPy arrays, Python lists, or pandas DataFrames?
  10. What is ProcessPoolExecutor?
  11. How does L.range(self.ds) work?
  12. What is __iter__?
  13. What is first?
  14. What is permute? Why is it needed?
  15. What is a recursive function? How does it help us define the parameters method?
  16. Write a recursive function that returns the first 20 items of the Fibonacci sequence.
  17. What is super?
  18. Why do subclasses of Module need to override forward instead of defining __call__?
  19. In ConvLayer, why does init depend on act?
  20. Why does Sequential need to call register_modules?
  21. Write a hook that prints the shape of every layer's activations.
  22. What is "LogSumExp"?
  23. Why is log_softmax useful?
  24. What is GetAttr? How is it helpful for callbacks?
  25. Reimplement one of the callbacks in this chapter without inheriting from Callback or GetAttr.
  26. What does Learner.__call__ do?
  27. What is getattr? (Note the case difference to GetAttr!)
  28. Why is there a try block in fit?
  29. Why do we check for model.training in one_batch?
  30. What is store_attr?
  31. What is the purpose of TrackResults.before_epoch?
  32. What does model.cuda do? How does it work?
  33. Why do we need to check model.training in LRFinder and OneCycle?
  34. Use cosine annealing in OneCycle.

Further Research

  1. Write resnet18 from scratch (refer to <> as needed), and train it with the Learner in this chapter.
  2. Implement a batchnorm layer from scratch and use it in your resnet18.
  3. Write a Mixup callback for use in this chapter.
  4. Add momentum to SGD.
  5. Pick a few features that you're interested in from fastai (or any other library) and implement them in this chapter.
  6. Pick a research paper that's not yet implemented in fastai or PyTorch and implement it in this chapter.
    • Port it over to fastai.
    • Submit a pull request to fastai, or create your own extension module and release it.
    • Hint: you may find it helpful to use nbdev to create and deploy your package.