---
title: Gradient Blending
keywords: fastai
sidebar: home_sidebar
summary: "Callback used to apply gradient blending to multi-modal models."
description: "Callback used to apply gradient blending to multi-modal models."
nb_path: "nbs/062_callback.gblend.ipynb"
---
This is an unofficial PyTorch implementation by Ignacio Oguiza (timeseriesAI@gmail.com) based on: Wang, W., Tran, D., & Feiszli, M. (2020). What Makes Training Multi-Modal Classification Networks Hard? In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 12695-12705).
from fastai.data.transforms import *
from tsai.data.all import *
from tsai.models.utils import *
from tsai.models.XCM import *
from tsai.models.TabModel import *
from tsai.models.MultiInputNet import *
from tsai.learner import *
# Time-series branch: load the NATOPS UCR dataset, derive a feature table
# from it (consumed by the tabular branch below), and build the time-series
# dataloaders plus an XCM model sized from those dataloaders.
dsid = 'NATOPS'
X, y, splits = get_UCR_data(dsid, split_data=False)
ts_features_df = get_ts_features(X, y)  # feature DataFrame reused for the tabular model

# No item transform on X; labels go through Categorize.
tfms = [None, [Categorize()]]
batch_tfms = TSStandardize()
ts_dls = get_ts_dls(
    X,
    y,
    splits=splits,
    tfms=tfms,
    batch_tfms=batch_tfms,
)
ts_model = build_ts_model(XCM, dls=ts_dls, window_perc=0.5)
# ts features: build the tabular branch from the feature DataFrame,
# reusing the same train/valid splits as the time-series branch.
cat_names, y_names = None, 'target'  # no categorical columns; 'target' is the label column
# NOTE(review): [:-2] drops the last two columns of ts_features_df. If
# get_ts_features only appends a single 'target' column, this also drops
# one feature column — confirm the intended column set.
cont_names = ts_features_df.columns[:-2]
tab_dls = get_tabular_dls(
    ts_features_df,
    cat_names=cat_names,
    cont_names=cont_names,
    y_names=y_names,
    splits=splits,
)
tab_model = build_tabular_model(TabModel, dls=tab_dls)
# mixed: merge both dataloaders, wrap both models in a single multi-input
# network, and train it for one cycle with the GBlend callback attached.
# NOTE(review): GBlend is not imported in the lines shown here; presumably it
# is defined earlier in this notebook/module — confirm.
mixed_dls = get_mixed_dls(ts_dls, tab_dls)
MultiModalNet = MultiInputNet(ts_model, tab_model, c_out=mixed_dls.c)
gblend = GBlend(V_pct=0.5, n=(10, 5), sel_metric=None)
learn = Learner(
    mixed_dls,
    MultiModalNet,
    metrics=[accuracy, RocAuc()],
    cbs=gblend,
)
learn.fit_one_cycle(1, lr_max=1e-3)