from fastai.imports import *
from fastai.tabular.all import *
from kaggle import api
# Kaggle dataset identifier in "<owner>/<dataset>" form; reused below for the download.
data = 'camnugent/california-housing-prices'
# Print the Kaggle datasets matching that identifier.
api.dataset_list_cli(search=data)
ref                                  title                       size  lastUpdated          downloadCount  voteCount  usabilityRating  
-----------------------------------  -------------------------  -----  -------------------  -------------  ---------  ---------------  
camnugent/california-housing-prices  California Housing Prices  400KB  2017-11-24 03:14:59          70928        741  0.85294116       
# The download directory is named after the dataset part of the identifier
# (everything after the owner prefix).
path = Path(Path(data).name)
api.dataset_download_files(data, path=path)
# List what ended up in the download directory.
path.ls()
(#2) [Path('california-housing-prices/california-housing-prices.zip'),Path('california-housing-prices/housing.csv')]
import zipfile

# Open the archive by its known name instead of taking path.ls()[0]:
# directory listing order is not guaranteed and the folder can already hold
# housing.csv, so the first entry is not reliably the zip file.
# The `with` block closes the archive handle once extraction is done.
with zipfile.ZipFile(path/f'{path.name}.zip') as archive:
    archive.extractall(path)
path.ls()
(#2) [Path('california-housing-prices/california-housing-prices.zip'),Path('california-housing-prices/housing.csv')]
# Load housing.csv from the download directory into a DataFrame.
df = pd.read_csv(path/'housing.csv', low_memory=False)
# Preview the first rows.
df.head()
longitude latitude housing_median_age total_rooms total_bedrooms population households median_income median_house_value ocean_proximity
0 -122.23 37.88 41.0 880.0 129.0 322.0 126.0 8.3252 452600.0 NEAR BAY
1 -122.22 37.86 21.0 7099.0 1106.0 2401.0 1138.0 8.3014 358500.0 NEAR BAY
2 -122.24 37.85 52.0 1467.0 190.0 496.0 177.0 7.2574 352100.0 NEAR BAY
3 -122.25 37.85 52.0 1274.0 235.0 558.0 219.0 5.6431 341300.0 NEAR BAY
4 -122.25 37.85 52.0 1627.0 280.0 565.0 259.0 3.8462 342200.0 NEAR BAY
# Show the column names of the loaded table.
df.columns
Index(['longitude', 'latitude', 'housing_median_age', 'total_rooms',
       'total_bedrooms', 'population', 'households', 'median_income',
       'median_house_value', 'ocean_proximity'],
      dtype='object')
# Plot one histogram per numeric column on a 12x12 figure.
df.hist(figsize=(12,12))
array([[<AxesSubplot:title={'center':'longitude'}>,
        <AxesSubplot:title={'center':'latitude'}>,
        <AxesSubplot:title={'center':'housing_median_age'}>],
       [<AxesSubplot:title={'center':'total_rooms'}>,
        <AxesSubplot:title={'center':'total_bedrooms'}>,
        <AxesSubplot:title={'center':'population'}>],
       [<AxesSubplot:title={'center':'households'}>,
        <AxesSubplot:title={'center':'median_income'}>,
        <AxesSubplot:title={'center':'median_house_value'}>]],
      dtype=object)
# Columns to log-transform in place before re-plotting the histograms.
# NOTE: this overwrites the original values, so running it a second time
# would take the log of already-logged data.
x = ['total_rooms', 'total_bedrooms', 'population', 'households']
df[x] = np.log(df[x])
df.hist(figsize=(12, 12))
array([[<AxesSubplot:title={'center':'longitude'}>,
        <AxesSubplot:title={'center':'latitude'}>,
        <AxesSubplot:title={'center':'housing_median_age'}>],
       [<AxesSubplot:title={'center':'total_rooms'}>,
        <AxesSubplot:title={'center':'total_bedrooms'}>,
        <AxesSubplot:title={'center':'population'}>],
       [<AxesSubplot:title={'center':'households'}>,
        <AxesSubplot:title={'center':'median_income'}>,
        <AxesSubplot:title={'center':'median_house_value'}>]],
      dtype=object)
# Column roles for the tabular pipeline: one categorical feature and eight
# continuous ones; 'median_house_value' is the regression target.
cat = ['ocean_proximity']
cont = [
    'longitude', 'latitude', 'housing_median_age', 'total_rooms',
    'total_bedrooms', 'population', 'households', 'median_income',
]
procs = [Categorify, FillMissing, Normalize]

# NOTE: `to` is bound to the result of .dataloaders(), not to the
# TabularPandas object itself; the same name is later handed to
# tabular_learner.
to = TabularPandas(
    df,
    procs=procs,
    cat_names=cat,
    cont_names=cont,
    y_names='median_house_value',
    y_block=RegressionBlock(),
    splits=RandomSplitter()(df),  # random train/validation split
    reduce_memory=False,
).dataloaders(path='.')

# Processed features and targets for the training and validation subsets,
# reused by the scikit-learn models below.
xs, y = to.train.xs, to.train.y
val_xs, val_y = to.valid.xs, to.valid.y

Random forest (RF)

from sklearn.ensemble import RandomForestRegressor

# Fit a 100-tree forest with min_samples_leaf=5 on the training split,
# then report the mean absolute error on the validation split.
m = RandomForestRegressor(n_estimators=100, min_samples_leaf=5)
m.fit(xs, y)
val_preds = m.predict(val_xs)
print('MAE:', abs(val_y - val_preds).mean())

# Feature importances, largest first: plotted as horizontal bars and shown
# as a table.
x = pd.DataFrame({'cols': xs.columns, 'imp': m.feature_importances_})
x = x.sort_values('imp', ascending=False)
x.set_index('cols').plot(kind='barh')
x
MAE: 32150.47857078855
cols imp
9 median_income 0.551278
0 ocean_proximity 0.117905
2 longitude 0.108562
3 latitude 0.106565
4 housing_median_age 0.046819
7 population 0.023806
5 total_rooms 0.017090
6 total_bedrooms 0.014583
8 households 0.013367
1 total_bedrooms_na 0.000025

Decision tree (DT)

from sklearn.tree import DecisionTreeRegressor, export_graphviz

# A single regression tree capped at 30 leaves, fitted on the training split.
m = DecisionTreeRegressor(max_leaf_nodes=30)
m.fit(xs, y)

# Mean absolute error of the tree on the validation split.
preds = m.predict(val_xs)
abs_err = abs(val_y - preds)
abs_err.mean()
48784.69675117493
import graphviz

def draw_tree(t, df, size=10, ratio=0.6, precision=2, **kwargs):
    """Render a fitted scikit-learn decision tree as a Graphviz graph.

    t: fitted tree estimator handed to `export_graphviz`.
    df: DataFrame whose column names label the tree's features.
    size, ratio: values written into the graph's `size` and `ratio` attributes.
    precision: forwarded to `export_graphviz`.
    **kwargs: any further keyword arguments for `export_graphviz`.

    Returns a `graphviz.Source` built from the generated DOT text.
    """
    # Pass a plain list of names: that is the form every scikit-learn release
    # documents for `feature_names`, whereas a pandas Index is presumably not
    # accepted by all of them (stricter parameter validation).
    s = export_graphviz(t, out_file=None, feature_names=list(df.columns), filled=True, rounded=True,
                        special_characters=True, rotate=False, precision=precision, **kwargs)
    # Insert the size/ratio attributes right after the graph's opening brace.
    # A literal string replace avoids treating '{' as regex syntax and does not
    # depend on `re` being in scope through a star import.
    return graphviz.Source(s.replace('Tree {', f'Tree {{ size={size}; ratio={ratio}'))

# Draw the fitted decision tree, labelling features with the columns of xs.
draw_tree(m, xs, size=10)
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> Tree 0 median_income ≤ 0.6 squared_error = 13450096298.46 samples = 16512 value = 207451.23 1 median_income ≤ -0.39 squared_error = 8407931988.56 samples = 12979 value = 173549.01 0->1 True 2 median_income ≤ 1.56 squared_error = 12239470987.85 samples = 3533 value = 331996.07 0->2 False 3 ocean_proximity ≤ 1.5 squared_error = 5608496426.91 samples = 6508 value = 136635.1 1->3 4 ocean_proximity ≤ 2.5 squared_error = 8474682591.94 samples = 6471 value = 210673.98 1->4 11 longitude ≤ 0.63 squared_error = 5305068924.25 samples = 2272 value = 172086.5 3->11 12 ocean_proximity ≤ 2.5 squared_error = 4735596941.86 samples = 4236 value = 117620.57 3->12 33 latitude ≤ -0.35 squared_error = 8027839253.15 samples = 838 value = 199719.35 11->33 34 squared_error = 3006960893.71 samples = 1434 value = 155938.43 11->34 41 squared_error = 8932852306.54 samples = 516 value = 224786.46 33->41 42 squared_error = 3957031571.03 samples = 322 value = 159549.69 33->42 13 squared_error = 1618563072.66 samples = 2862 value = 92402.73 12->13 14 longitude ≤ 1.21 squared_error = 7144451228.7 samples = 1374 value = 170148.55 12->14 43 latitude ≤ 1.06 squared_error = 7932045122.76 samples = 1104 value = 181822.48 14->43 44 squared_error = 1088357213.85 samples = 270 value = 122415.19 14->44 45 longitude ≤ -1.37 squared_error = 8053807525.29 samples = 893 value = 197257.01 43->45 46 squared_error = 2141470900.47 samples = 211 value = 116500.0 43->46 47 squared_error = 10674715505.32 samples = 136 value = 283391.22 45->47 48 latitude ≤ 0.91 squared_error = 6010591719.13 samples = 757 value = 181782.44 45->48 53 squared_error = 6044835831.42 samples = 411 value = 206049.16 48->53 54 squared_error = 4439504845.45 samples = 346 value = 152956.94 48->54 7 ocean_proximity ≤ 1.5 squared_error = 6944833238.75 samples = 4829 value = 195279.97 4->7 8 longitude ≤ -1.38 squared_error = 10227315294.42 samples = 
1642 value = 255946.63 4->8 9 longitude ≤ 0.63 squared_error = 6603157420.73 samples = 2999 value = 227343.2 7->9 10 squared_error = 3059009268.71 samples = 1830 value = 142734.81 7->10 19 latitude ≤ -0.68 squared_error = 8577116546.33 samples = 1371 value = 255556.79 9->19 20 squared_error = 3705941264.15 samples = 1628 value = 203583.49 9->20 21 longitude ≤ 0.6 squared_error = 10906702197.27 samples = 485 value = 315459.68 19->21 22 squared_error = 4262355206.19 samples = 886 value = 222765.7 19->22 31 squared_error = 7005312830.96 samples = 223 value = 366057.56 21->31 32 latitude ≤ -0.74 squared_error = 10193604375.28 samples = 262 value = 272393.55 21->32 35 squared_error = 5039079358.57 samples = 169 value = 228079.88 32->35 36 squared_error = 9507385666.05 samples = 93 value = 352920.55 32->36 17 housing_median_age ≤ 1.49 squared_error = 10110054299.53 samples = 456 value = 316712.38 8->17 18 longitude ≤ 1.21 squared_error = 8306840619.78 samples = 1186 value = 232583.07 8->18 39 squared_error = 7570491263.42 samples = 249 value = 276532.15 17->39 40 latitude ≤ 0.99 squared_error = 8886817723.63 samples = 207 value = 365045.11 17->40 51 squared_error = 4345060048.08 samples = 110 value = 312891.85 40->51 52 squared_error = 7454890227.25 samples = 97 value = 424187.99 40->52 27 latitude ≤ 0.91 squared_error = 8177004850.63 samples = 1041 value = 244176.97 18->27 28 squared_error = 1345679042.09 samples = 145 value = 149346.9 18->28 29 squared_error = 9130254958.69 samples = 633 value = 276095.3 27->29 30 squared_error = 2665192995.42 samples = 408 value = 194656.62 27->30 5 housing_median_age ≤ -0.1 squared_error = 8946122397.82 samples = 2483 value = 292783.99 2->5 6 median_income ≤ 2.05 squared_error = 7793088915.86 samples = 1050 value = 424723.3 2->6 15 median_income ≤ 0.95 squared_error = 6509956975.94 samples = 1472 value = 266567.96 5->15 16 median_income ≤ 1.03 squared_error = 10035513797.35 samples = 1011 value = 330954.11 5->16 37 squared_error = 
5814008906.84 samples = 712 value = 240696.23 15->37 38 squared_error = 5947410362.83 samples = 760 value = 290805.7 15->38 25 squared_error = 9187757434.12 samples = 642 value = 304493.84 16->25 26 squared_error = 8172966970.11 samples = 369 value = 376990.67 16->26 23 housing_median_age ≤ -0.17 squared_error = 7769016670.28 samples = 408 value = 373146.5 6->23 24 households ≤ -3.67 squared_error = 5043430957.38 samples = 642 value = 457501.09 6->24 49 squared_error = 5688658380.81 samples = 244 value = 339658.68 23->49 50 squared_error = 6713344468.47 samples = 164 value = 422969.84 23->50 55 squared_error = 17028586432.78 samples = 14 value = 271057.29 24->55 56 median_income ≤ 2.7 squared_error = 3984037152.54 samples = 628 value = 461657.48 24->56 57 squared_error = 5361175172.58 samples = 310 value = 432276.53 56->57 58 squared_error = 979667506.0 samples = 318 value = 490299.28 56->58

Deep neural network (DNN)

# Small tabular network with two hidden layers of 10 units each.
# Use the named `mae` metric instead of the loss object L1LossFlat(): the loss
# object has no metric name, so the results table labels its column "None".
# `mae` reports the same mean absolute error under a readable header.
learn = tabular_learner(to, metrics=mae, layers=[10,10])
# Run the learning-rate finder and show the `slide` and `valley` suggestions.
learn.lr_find(suggest_funcs=(slide, valley))
SuggestedLRs(slide=6.309573450380412e-07, valley=0.17378008365631104)
# Train for 10 epochs at a fixed learning rate of 0.1.
learn.fit(10, lr=0.1)
# Plot the losses recorded during training.
learn.recorder.plot_loss()
epoch train_loss valid_loss None time
0 52540518400.000000 49501884416.000000 195610.734375 00:02
1 41421684736.000000 36557541376.000000 169839.937500 00:02
2 30079684608.000000 28439451648.000000 151214.703125 00:02
3 20997189632.000000 15527985152.000000 108336.179688 00:02
4 14592946176.000000 18044657664.000000 120398.703125 00:02
5 10348133376.000000 7776253952.000000 70381.062500 00:02
6 7479977984.000000 6708659200.000000 64639.753906 00:02
7 5428857344.000000 4700456960.000000 49034.535156 00:02
8 4591148032.000000 3880753408.000000 42761.687500 00:02
9 4247306496.000000 3972862464.000000 43510.292969 00:02