--- title: Data Tabular keywords: fastai sidebar: home_sidebar summary: "Main Tabular functions used throughout the library. This is helpful when you have additional time series data like metadata, time series features, etc." description: "Main Tabular functions used throughout the library. This is helpful when you have additional time series data like metadata, time series features, etc." nb_path: "nbs/021_data.tabular.ipynb" ---
get_tabular_ds
[source]
get_tabular_ds(df, procs=[Categorify, FillMissing, Normalize], cat_names=None, cont_names=None, y_names=None, groupby=None, y_block=None, splits=None, do_setup=True, inplace=False, reduce_memory=True, device=None)
get_tabular_dls
[source]
get_tabular_dls(df, procs=[Categorify, FillMissing, Normalize], cat_names=None, cont_names=None, y_names=None, bs=64, y_block=None, splits=None, do_setup=True, inplace=False, reduce_memory=True, device=None, path='.')
preprocess_df
[source]
preprocess_df(df, procs=[Categorify, FillMissing, Normalize], cat_names=None, cont_names=None, y_names=None, sample_col=None, reduce_memory=True)
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
# df['salary'] = np.random.rand(len(df)) # uncomment to simulate a cont dependent variable
cat_names = ['workclass', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex',
'capital-gain', 'capital-loss', 'native-country']
cont_names = ['age', 'fnlwgt', 'hours-per-week']
target = ['salary']
splits = RandomSplitter()(range_of(df))
dls = get_tabular_dls(df, cat_names=cat_names, cont_names=cont_names, y_names='salary', splits=splits, bs=512)
dls.show_batch()
workclass | education | education-num | marital-status | occupation | relationship | race | sex | capital-gain | capital-loss | native-country | age | fnlwgt | hours-per-week | salary | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Private | HS-grad | 9 | Married-civ-spouse | Tech-support | Husband | White | Male | 0 | 0 | United-States | 51.000000 | 133336.001891 | 40.000000 | >=50k |
1 | Self-emp-not-inc | HS-grad | 9 | Married-civ-spouse | Craft-repair | Husband | White | Male | 0 | 0 | United-States | 35.000000 | 278632.003101 | 40.000000 | <50k |
2 | Private | Assoc-voc | 11 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | United-States | 54.000001 | 103344.999857 | 40.000000 | >=50k |
3 | Local-gov | Bachelors | 13 | Separated | Prof-specialty | Not-in-family | White | Female | 0 | 0 | United-States | 47.000000 | 93475.998324 | 69.999999 | <50k |
4 | Private | Bachelors | 13 | Never-married | Prof-specialty | Not-in-family | White | Female | 0 | 0 | Cuba | 34.000000 | 351810.005223 | 45.000000 | <50k |
5 | ? | Some-college | 10 | Widowed | ? | Unmarried | Black | Female | 0 | 4356 | United-States | 65.999999 | 186060.999965 | 40.000000 | <50k |
6 | State-gov | HS-grad | 9 | Never-married | Transport-moving | Own-child | White | Male | 0 | 0 | United-States | 20.000000 | 200819.000080 | 40.000000 | <50k |
7 | Private | 11th | 7 | Never-married | Adm-clerical | Own-child | White | Female | 0 | 0 | United-States | 18.000000 | 41972.997603 | 5.000000 | <50k |
8 | Private | HS-grad | 9 | Divorced | Sales | Not-in-family | White | Female | 0 | 0 | United-States | 60.000001 | 227265.999817 | 33.000000 | <50k |
9 | Local-gov | HS-grad | 9 | Married-civ-spouse | Transport-moving | Husband | White | Male | 0 | 0 | United-States | 46.000000 | 172821.999572 | 40.000000 | <50k |
metrics = mae if dls.c == 1 else accuracy
learn = tabular_learner(dls, layers=[200, 100], y_range=None, metrics=metrics)
learn.fit(1, 1e-2)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.348326 | 0.286412 | 0.868090 | 00:04 |
learn.dls.one_batch()
(tensor([[ 5, 16, 10, ..., 1, 1, 40], [ 5, 12, 9, ..., 1, 1, 40], [ 5, 12, 9, ..., 1, 1, 40], ..., [ 5, 12, 9, ..., 1, 1, 40], [ 6, 12, 9, ..., 1, 1, 40], [ 5, 7, 5, ..., 1, 1, 40]]), tensor([[ 1.7120, -1.1040, -0.0356], [-0.4838, 0.1435, -0.0356], [ 1.6388, -0.2557, -0.0356], ..., [ 1.4192, 0.2211, -0.0356], [ 1.3460, 1.7082, -0.0356], [-1.3621, -0.6940, -0.0356]]), tensor([[0], [0], [0], [0], [0], [0], [0], [1], [1], [0], [1], [1], [0], [1], [1], [0], [0], [0], [1], [0], [0], [1], [0], [0], [0], [0], [0], [1], [1], [0], [0], [0], [0], [1], [0], [1], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [1], [0], [1], [0], [0], [0], [0], [0], [0], [0], [0], [1], [1], [0], [0], [0], [0], [0], [0], [0], [0], [1], [0], [0], [0], [0], [0], [0], [0], [0], [1], [1], [0], [0], [0], [0], [0], [0], [0], [0], [0], [1], [0], [0], [0], [0], [0], [0], [1], [1], [0], [0], [0], [1], [0], [0], [0], [0], [0], [0], [0], [0], [1], [0], [0], [0], [0], [1], [1], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [1], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [1], [1], [0], [0], [1], [0], [1], [0], [1], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [1], [0], [0], [0], [0], [0], [0], [0], [0], [1], [0], [1], [0], [0], [0], [0], [0], [0], [0], [0], [0], [1], [0], [1], [0], [1], [0], [1], [0], [0], [0], [0], [0], [1], [1], [0], [0], [0], [0], [1], [0], [0], [0], [1], [1], [0], [0], [0], [0], [0], [1], [0], [0], [1], [0], [0], [1], [0], [0], [1], [0], [0], [1], [0], [1], [0], [0], [1], [0], [1], [1], [0], [0], [0], [0], [0], [0], [0], [1], [0], [1], [1], [0], [0], [0], [0], [0], [0], [0], [0], [1], [0], [0], [1], [0], [0], [0], [0], [0], [1], [1], [0], [0], [1], [0], [1], [1], [0], [0], [1], [0], [0], [1], [0], [0], [0], [0], [1], [1], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [1], [0], [1], [1], [1], [1], [0], [0], [0], [0], [0], [0], [1], [1], [0], [0], [0], [0], [0], [0], [1], 
[0], [0], [0], [1], [0], [0], [0], [1], [0], [0], [0], [0], [1], [0], [0], [1], [1], [0], [0], [0], [0], [0], [1], [0], [1], [1], [0], [0], [0], [0], [0], [1], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [1], [0], [1], [1], [0], [0], [0], [0], [0], [0], [1], [0], [1], [0], [1], [0], [0], [0], [0], [0], [0], [1], [0], [0], [1], [1], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [1], [0], [1], [0], [0], [1], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [1], [1], [0], [1], [0], [0], [0], [1], [0], [1], [0], [0], [0], [0], [1], [1], [0], [0], [1], [1], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [1], [1], [1], [0], [0], [0], [0], [0], [0], [1], [1], [0], [0], [0], [1], [0], [0], [0], [0], [0], [0], [1], [0], [0], [0], [1], [0], [0], [0], [0], [1], [0], [1], [1], [1], [0], [0], [0], [1], [0], [1], [0], [0], [0], [0], [0], [0], [1], [1], [0], [1], [0], [0], [0], [0], [0], [0], [0]], dtype=torch.int8))
learn.model
TabularModel( (embeds): ModuleList( (0): Embedding(10, 6) (1): Embedding(17, 8) (2): Embedding(17, 8) (3): Embedding(8, 5) (4): Embedding(16, 8) (5): Embedding(7, 5) (6): Embedding(6, 4) (7): Embedding(3, 3) (8): Embedding(117, 23) (9): Embedding(90, 20) (10): Embedding(43, 13) ) (emb_drop): Dropout(p=0.0, inplace=False) (bn_cont): BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (layers): Sequential( (0): LinBnDrop( (0): Linear(in_features=106, out_features=200, bias=False) (1): ReLU(inplace=True) (2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): LinBnDrop( (0): Linear(in_features=200, out_features=100, bias=False) (1): ReLU(inplace=True) (2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): LinBnDrop( (0): Linear(in_features=100, out_features=2, bias=True) ) ) )
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
cat_names = ['workclass', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex',
'capital-gain', 'capital-loss', 'native-country']
cont_names = ['age', 'fnlwgt', 'hours-per-week']
target = ['salary']
df, procs = preprocess_df(df, procs=[Categorify, FillMissing, Normalize], cat_names=cat_names, cont_names=cont_names, y_names=target,
sample_col=None, reduce_memory=True)
df.head()
workclass | education | education-num | marital-status | occupation | relationship | race | sex | capital-gain | capital-loss | native-country | age | fnlwgt | hours-per-week | salary | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 5 | 8 | 12 | 3 | 0 | 6 | 5 | 1 | 1 | 48 | 40 | 0.763796 | -0.838084 | -0.035429 | 1 |
1 | 5 | 13 | 14 | 1 | 5 | 2 | 5 | 2 | 101 | 1 | 40 | 0.397233 | 0.444987 | 0.369519 | 1 |
2 | 5 | 12 | 0 | 1 | 0 | 5 | 3 | 1 | 1 | 1 | 40 | -0.042642 | -0.886734 | -0.683348 | 0 |
3 | 6 | 15 | 15 | 3 | 11 | 1 | 2 | 2 | 1 | 1 | 40 | -0.042642 | -0.728873 | -0.035429 | 1 |
4 | 7 | 6 | 0 | 3 | 9 | 6 | 3 | 1 | 1 | 1 | 40 | 0.250608 | -1.018314 | 0.774468 | 0 |
procs.classes, procs.means, procs.stds
({'workclass': ['#na#', ' ?', ' Federal-gov', ' Local-gov', ' Never-worked', ' Private', ' Self-emp-inc', ' Self-emp-not-inc', ' State-gov', ' Without-pay'], 'education': ['#na#', ' 10th', ' 11th', ' 12th', ' 1st-4th', ' 5th-6th', ' 7th-8th', ' 9th', ' Assoc-acdm', ' Assoc-voc', ' Bachelors', ' Doctorate', ' HS-grad', ' Masters', ' Preschool', ' Prof-school', ' Some-college'], 'education-num': ['#na#', 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0], 'marital-status': ['#na#', ' Divorced', ' Married-AF-spouse', ' Married-civ-spouse', ' Married-spouse-absent', ' Never-married', ' Separated', ' Widowed'], 'occupation': ['#na#', ' ?', ' Adm-clerical', ' Armed-Forces', ' Craft-repair', ' Exec-managerial', ' Farming-fishing', ' Handlers-cleaners', ' Machine-op-inspct', ' Other-service', ' Priv-house-serv', ' Prof-specialty', ' Protective-serv', ' Sales', ' Tech-support', ' Transport-moving'], 'relationship': ['#na#', ' Husband', ' Not-in-family', ' Other-relative', ' Own-child', ' Unmarried', ' Wife'], 'race': ['#na#', ' Amer-Indian-Eskimo', ' Asian-Pac-Islander', ' Black', ' Other', ' White'], 'sex': ['#na#', ' Female', ' Male'], 'capital-gain': ['#na#', 0, 114, 401, 594, 914, 991, 1055, 1086, 1111, 1151, 1173, 1409, 1424, 1455, 1471, 1506, 1639, 1797, 1831, 1848, 2009, 2036, 2050, 2062, 2105, 2174, 2176, 2202, 2228, 2290, 2329, 2346, 2354, 2387, 2407, 2414, 2463, 2538, 2580, 2597, 2635, 2653, 2829, 2885, 2907, 2936, 2961, 2964, 2977, 2993, 3103, 3137, 3273, 3325, 3411, 3418, 3432, 3456, 3464, 3471, 3674, 3781, 3818, 3887, 3908, 3942, 4064, 4101, 4386, 4416, 4508, 4650, 4687, 4787, 4865, 4931, 4934, 5013, 5060, 5178, 5455, 5556, 5721, 6097, 6360, 6418, 6497, 6514, 6723, 6767, 6849, 7298, 7430, 7443, 7688, 7896, 7978, 8614, 9386, 9562, 10520, 10566, 10605, 11678, 13550, 14084, 14344, 15020, 15024, 15831, 18481, 20051, 22040, 25124, 25236, 27828, 34095, 41310, 99999], 'capital-loss': ['#na#', 0, 155, 213, 323, 419, 625, 653, 810, 
880, 974, 1092, 1138, 1258, 1340, 1380, 1408, 1411, 1485, 1504, 1539, 1564, 1573, 1579, 1590, 1594, 1602, 1617, 1628, 1648, 1651, 1668, 1669, 1672, 1719, 1721, 1726, 1735, 1740, 1741, 1755, 1762, 1816, 1825, 1844, 1848, 1876, 1887, 1902, 1944, 1974, 1977, 1980, 2001, 2002, 2042, 2051, 2057, 2080, 2129, 2149, 2163, 2174, 2179, 2201, 2205, 2206, 2231, 2238, 2246, 2258, 2267, 2282, 2339, 2352, 2377, 2392, 2415, 2444, 2457, 2467, 2472, 2489, 2547, 2559, 2603, 2754, 2824, 3004, 3683, 3770, 3900, 4356], 'native-country': ['#na#', ' ?', ' Cambodia', ' Canada', ' China', ' Columbia', ' Cuba', ' Dominican-Republic', ' Ecuador', ' El-Salvador', ' England', ' France', ' Germany', ' Greece', ' Guatemala', ' Haiti', ' Holand-Netherlands', ' Honduras', ' Hong', ' Hungary', ' India', ' Iran', ' Ireland', ' Italy', ' Jamaica', ' Japan', ' Laos', ' Mexico', ' Nicaragua', ' Outlying-US(Guam-USVI-etc)', ' Peru', ' Philippines', ' Poland', ' Portugal', ' Puerto-Rico', ' Scotland', ' South', ' Taiwan', ' Thailand', ' Trinadad&Tobago', ' United-States', ' Vietnam', ' Yugoslavia']}, {'age': 38.58164675532078, 'fnlwgt': 189778.36651208502, 'hours-per-week': 40.437455852092995}, {'age': 13.640223192304274, 'fnlwgt': 105548.3568809908, 'hours-per-week': 12.347239175707989})