---
title: Layers
keywords: fastai
sidebar: home_sidebar
summary: "Helper functions used to build PyTorch timeseries models."
description: "Helper functions used to build PyTorch timeseries models."
nb_path: "nbs/100_models.layers.ipynb"
---
bs = 2
c_in = 3
c_out = 5
h = 16
w = 20
t = torch.rand(bs, c_in, h, w)
test_eq(Conv2dSame(c_in, c_out, ks=3, stride=1, dilation=1, bias=False)(t).shape, (bs, c_out, h, w))
test_eq(Conv2dSame(c_in, c_out, ks=(3, 1), stride=1, dilation=1, bias=False)(t).shape, (bs, c_out, h, w))
test_eq(Conv2dSame(c_in, c_out, ks=3, stride=(1, 1), dilation=(2, 2), bias=False)(t).shape, (bs, c_out, h, w))
test_eq(Conv2dSame(c_in, c_out, ks=3, stride=(2, 2), dilation=(1, 1), bias=False)(t).shape, (bs, c_out, h//2, w//2))
test_eq(Conv2dSame(c_in, c_out, ks=3, stride=(2, 2), dilation=(2, 2), bias=False)(t).shape, (bs, c_out, h//2, w//2))
test_eq(Conv2d(c_in, c_out, ks=3, padding='same', stride=1, dilation=1, bias=False)(t).shape, (bs, c_out, h, w))
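For odd kernel sizes and stride 1, 'same' output sizes can be reproduced with a plain nn.Conv2d by padding each side by dilation*(ks-1)//2; even kernel sizes need the asymmetric padding that a 'same' layer such as Conv2dSame typically applies explicitly in its forward pass. A minimal sketch of the stride-1 case, not the library implementation:
import torch, torch.nn as nn
ks, dilation = 3, 2
pad = dilation * (ks - 1) // 2                     # symmetric padding for odd ks
conv = nn.Conv2d(3, 5, ks, stride=1, dilation=dilation, padding=pad, bias=False)
x = torch.rand(2, 3, 16, 20)
assert conv(x).shape == (2, 5, 16, 20)             # spatial dims are preserved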
bs = 2
c_in = 3
c_out = 5
seq_len = 512
t = torch.rand(bs, c_in, seq_len)
dilation = 1
test_eq(CausalConv1d(c_in, c_out, ks=3, dilation=dilation)(t).shape, Conv1d(c_in, c_out, ks=3, padding="same", dilation=dilation)(t).shape)
dilation = 2
test_eq(CausalConv1d(c_in, c_out, ks=3, dilation=dilation)(t).shape, Conv1d(c_in, c_out, ks=3, padding="same", dilation=dilation)(t).shape)
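Causal convolutions are usually built by left-padding the input by (ks-1)*dilation so that each output step only depends on current and past inputs; a minimal sketch under that assumption, not the CausalConv1d source:
import torch, torch.nn as nn, torch.nn.functional as F
ks, dilation, c_in, c_out = 3, 2, 3, 5
pad = (ks - 1) * dilation
conv = nn.Conv1d(c_in, c_out, ks, dilation=dilation)
x = torch.rand(2, c_in, 512)
out = conv(F.pad(x, (pad, 0)))                     # pad on the left only
assert out.shape == (2, c_out, 512)                # same length as the input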
bs = 2
ni = 3
nf = 5
seq_len = 6
ks = 3
t = torch.rand(bs, ni, seq_len)
test_eq(Conv1d(ni, nf, ks, padding=0)(t).shape, (bs, nf, seq_len - (2 * (ks//2))))
test_eq(Conv1d(ni, nf, ks, padding='valid')(t).shape, (bs, nf, seq_len - (2 * (ks//2))))
test_eq(Conv1d(ni, nf, ks, padding='same')(t).shape, (bs, nf, seq_len))
test_eq(Conv1d(ni, nf, ks, padding='causal')(t).shape, (bs, nf, seq_len))
test_error('use kernel_size or ks but not both simultaneously', Conv1d, ni, nf, kernel_size=3, ks=3)
test_error('you need to pass a ks', Conv1d, ni, nf)
conv = Conv1d(ni, nf, ks, padding='same')
init_linear(conv, None, init='auto', bias_std=.01)
conv
conv = Conv1d(ni, nf, ks, padding='causal')
init_linear(conv, None, init='auto', bias_std=.01)
conv
conv = Conv1d(ni, nf, ks, padding='valid')
init_linear(conv, None, init='auto', bias_std=.01)
weight_norm(conv)
conv
conv = Conv1d(ni, nf, ks, padding=0)
init_linear(conv, None, init='auto', bias_std=.01)
weight_norm(conv)
conv
bs = 64
c_in = 6
c_out = 5
seq_len = 512
t = torch.rand(bs, c_in, seq_len)
test_eq(SeparableConv1d(c_in, c_out, 3)(t).shape, (bs, c_out, seq_len))
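A separable convolution is commonly factored into a depthwise conv (one filter per input channel) followed by a pointwise 1x1 conv that mixes channels, which cuts the parameter count. A rough sketch of that pattern, not the SeparableConv1d source:
import torch, torch.nn as nn
c_in, c_out, ks = 6, 5, 3
depthwise = nn.Conv1d(c_in, c_in, ks, padding=ks//2, groups=c_in)  # per-channel filtering
pointwise = nn.Conv1d(c_in, c_out, 1)                              # channel mixing
x = torch.rand(64, c_in, 512)
assert pointwise(depthwise(x)).shape == (64, c_out, 512)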
bs = 2
c_in = 3
c_out = 5
seq_len = 50
t = torch.rand(bs, c_in, seq_len)
t = (t - t.mean()) / t.std()
test_eq(AddCoords1d()(t).shape, (bs, c_in + 1, seq_len))
new_t = AddCoords1d()(t)
test_close(new_t.mean(),0, 1e-2)
test_close(new_t.std(), 1, 1e-2)
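AddCoords1d concatenates a position-encoding channel to the input (CoordConv-style), so downstream convolutions can condition on absolute position. A minimal sketch of the idea, assuming a normalized linear ramp as the extra channel:
import torch
bs, c_in, seq_len = 2, 3, 50
x = torch.rand(bs, c_in, seq_len)
coord = torch.linspace(-1, 1, seq_len)
coord = (coord - coord.mean()) / coord.std()       # keep the added channel roughly standardized
coord = coord.repeat(bs, 1, 1)                     # (bs, 1, seq_len)
x_coord = torch.cat([x, coord], dim=1)
assert x_coord.shape == (bs, c_in + 1, seq_len)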
t = torch.rand(8, 32, 12)
test_eq(SEModule1d(t.shape[1], 16, act=nn.ReLU, act_kwargs={})(t).shape, t.shape)
bs = 2
ni = 3
nf = 5
sl = 4
ks = 5
t = torch.rand(bs, ni, sl)
test_eq(ConvBlock(ni, nf, ks)(t).shape, (bs, nf, sl))
test_eq(ConvBlock(ni, nf, ks, padding='causal')(t).shape, (bs, nf, sl))
test_eq(ConvBlock(ni, nf, ks, coord=True)(t).shape, (bs, nf, sl))
test_eq(BN1d(ni)(t).shape, (bs, ni, sl))
test_eq(BN1d(ni).weight.data.mean().item(), 1.)
test_eq(BN1d(ni, zero_norm=True).weight.data.mean().item(), 0.)
test_eq(ConvBlock(ni, nf, ks, norm='batch', zero_norm=True)[1].weight.data.unique().item(), 0)
test_ne(ConvBlock(ni, nf, ks, norm='batch', zero_norm=False)[1].weight.data.unique().item(), 0)
test_eq(ConvBlock(ni, nf, ks, bias=False)[0].bias, None)
ConvBlock(ni, nf, ks, act=Swish, coord=True)
LinLnDrop(2, 3, p=.5)
bs = 2
nf = 5
sl = 4
t = torch.rand(bs, nf, sl)
test_eq(Permute(0,2,1)(t).shape, (bs, sl, nf))
test_eq(Max(1)(t).shape, (bs, sl))
test_eq(Transpose(1,2)(t).shape, (bs, sl, nf))
test_eq(Transpose(1,2, contiguous=True)(t).shape, (bs, sl, nf))
test_eq(View(-1, 2, 10)(t).shape, (bs, 1, 2, 10))
test_eq(Reshape(-1, 2, 10)(t).shape, (bs, 1, 2, 10))
Transpose(1,2), Permute(0,2,1), View(-1, 2, 10), Transpose(1,2, contiguous=True), Reshape(-1, 2, 10), Noop
t = torch.ones(100,2,3)
test_eq(DropPath(0.)(t), t)
assert DropPath(0.5)(t).max() >= 1
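DropPath is the usual stochastic-depth recipe: during training, whole samples in the batch are zeroed with probability p and the survivors are rescaled by 1/(1-p), which is why the max above can exceed 1 with p=0.5. A minimal sketch of that recipe, not the library code:
import torch
def drop_path_sketch(x, p=0.5, training=True):
    if not training or p == 0.:
        return x
    keep = torch.rand(x.shape[0], *([1] * (x.ndim - 1))) > p   # per-sample keep mask
    return x * keep / (1 - p)                                   # rescale the survivors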
n_samples = 1000
n_classes = 3
t = (torch.rand(n_samples, n_classes) - .5) * 10
probas = F.softmax(t, -1)
sharpened_probas = Sharpen()(probas)
plt.plot(probas.flatten().sort().values, color='r')
plt.plot(sharpened_probas.flatten().sort().values, color='b')
plt.show()
test_gt(sharpened_probas[n_samples//2:].max(-1).values.sum().item(), probas[n_samples//2:].max(-1).values.sum().item())
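Sharpen is typically implemented as temperature sharpening: raise the probabilities to 1/T (with T < 1) and renormalize, which pushes the distribution towards its mode, as the plot above illustrates. A minimal sketch under that assumption:
import torch
def sharpen_sketch(p, T=0.5):
    p = p ** (1 / T)
    return p / p.sum(dim=-1, keepdim=True)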
bs = 2
c_out = 3
t = torch.rand(bs, c_out)
for calibrator, cal_name in zip(['temp', 'vector', 'matrix'], ['Temp_Scale', 'Vector_Scale', 'Matrix_Scale']):
    cal = get_calibrator(calibrator, n_classes=c_out)
    # print(calibrator)
    # print(cal.weight, cal.bias, '\n')
    test_eq(cal(t), t)
    test_eq(cal.__class__.__name__, cal_name)
for calibrator, cal_name in zip(['dtemp', 'dvector', 'dmatrix'], ['Temp_Scale', 'Vector_Scale', 'Matrix_Scale']):
    cal = get_calibrator(calibrator, n_classes=c_out)
    # print(calibrator)
    # print(cal.weight, cal.bias, '\n')
    test_eq(cal(t), F.log_softmax(t, dim=1))
    test_eq(cal.__class__.__name__, cal_name)
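The 'temp'/'vector'/'matrix' calibrators belong to the post-hoc calibration family of Guo et al. (2017); the simplest member, temperature scaling, just divides the logits by a single learned temperature (the 'd*' variants additionally apply a log-softmax, as tested above). A minimal sketch of plain temperature scaling, not the tsai classes:
import torch, torch.nn as nn
class TempScaleSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.temp = nn.Parameter(torch.ones(1))    # initialized so the layer starts as the identity
    def forward(self, x):
        return x / self.temp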
bs = 2
c_out = 3
t = torch.rand(bs, c_out)
test_eq(Temp_Scale()(t).shape, t.shape)
test_eq(Vector_Scale(c_out)(t).shape, t.shape)
test_eq(Matrix_Scale(c_out)(t).shape, t.shape)
test_eq(Temp_Scale(dirichlet=True)(t).shape, t.shape)
test_eq(Vector_Scale(c_out, dirichlet=True)(t).shape, t.shape)
test_eq(Matrix_Scale(c_out, dirichlet=True)(t).shape, t.shape)
test_eq(Temp_Scale()(t), t)
test_eq(Vector_Scale(c_out)(t), t)
test_eq(Matrix_Scale(c_out)(t), t)
bs = 2
c_out = 5
t = torch.rand(bs, c_out)
test_eq(Vector_Scale(c_out)(t), t)
test_eq(Vector_Scale(c_out).weight.data, torch.ones(c_out))
test_eq(Vector_Scale(c_out).weight.requires_grad, True)
test_eq(type(Vector_Scale(c_out).weight), torch.nn.parameter.Parameter)
bs = 2
c_out = 3
weight = 2
bias = 1
t = torch.rand(bs, c_out)
test_eq(Matrix_Scale(c_out)(t).shape, t.shape)
test_eq(Matrix_Scale(c_out).weight.requires_grad, True)
test_eq(type(Matrix_Scale(c_out).weight), torch.nn.parameter.Parameter)
bs, n_classes = 16, 3
class_priors = torch.rand(n_classes)
logits = torch.randn(bs, n_classes) * 2
test_eq(LogitAdjLayer(class_priors)(logits), logits + class_priors)
bs = 2
nf = 5
sl = 4
t = torch.rand(bs, nf, sl)
test_eq(MaxPPVPool1d()(t).shape, (bs, nf*2, 1))
test_eq(MaxPPVPool1d()(t).shape, AdaptiveConcatPool1d(1)(t).shape)
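MaxPPVPool1d concatenates, per channel, the max value and the proportion of positive values (PPV, the pooling popularized by ROCKET), which is why it doubles the channel dimension. A sketch assuming that construction:
import torch
x = torch.rand(2, 5, 4) - .5
pooled = torch.cat([x.max(dim=-1, keepdim=True).values,            # max pooling
                    (x > 0).float().mean(dim=-1, keepdim=True)],   # proportion of positive values
                   dim=1)
assert pooled.shape == (2, 10, 1)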
t = torch.randn(16, 64, 50)
head = gwa_pool_head(64, 5, 50)
test_eq(head(t).shape, (16, 5))
bs, c_in, seq_len = 16, 1, 50
c_out = 3
t = torch.rand(bs, c_in, seq_len)
test_eq(GAP1d()(t).shape, (bs, c_in))
test_eq(GACP1d()(t).shape, (bs, c_in*2))
bs, c_in, seq_len = 16, 4, 50
t = torch.rand(bs, c_in, seq_len)
test_eq(GAP1d()(t).shape, (bs, c_in))
test_eq(GACP1d()(t).shape, (bs, c_in*2))
test_eq(GAWP1d(c_in, seq_len, n_layers=2, ln=False, dropout=0.5, act=nn.ReLU(), zero_init=False)(t).shape, (bs, c_in))
test_eq(GAWP1d(c_in, seq_len, n_layers=2, ln=False, dropout=0.5, act=nn.ReLU(), zero_init=False)(t).shape, (bs, c_in))
test_eq(GAWP1d(c_in, seq_len, n_layers=1, ln=False, dropout=0.5, zero_init=False)(t).shape, (bs, c_in))
test_eq(GAWP1d(c_in, seq_len, n_layers=1, ln=False, dropout=0.5, zero_init=True)(t).shape, (bs, c_in))
test_eq(AttentionalPool1d(c_in, c_out)(t).shape, (bs, c_out, 1))
bs, c_in, seq_len = 16, 128, 50
c_out = 14
t = torch.rand(bs, c_in, seq_len)
attp = attentional_pool_head(c_in, c_out)
test_eq(attp(t).shape, (bs, c_out))
test_eq(get_act_fn(nn.ReLU).__repr__(), "ReLU()")
test_eq(get_act_fn(nn.ReLU()).__repr__(), "ReLU()")
test_eq(get_act_fn(nn.LeakyReLU, negative_slope=0.05).__repr__(), "LeakyReLU(negative_slope=0.05)")
test_eq(get_act_fn('reglu').__repr__(), "ReGLU()")
test_eq(get_act_fn('leakyrelu', negative_slope=0.05).__repr__(), "LeakyReLU(negative_slope=0.05)")
bs = 16
nf = 12
c_out = 2
seq_len = 20
t = torch.rand(bs, nf, seq_len)
test_eq(create_pool_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))
test_eq(create_pool_head(nf, c_out, seq_len, concat_pool=True, fc_dropout=0.5)(t).shape, (bs, c_out))
create_pool_head(nf, c_out, seq_len, concat_pool=True, bn=True, fc_dropout=.5)
bs = 16
nf = 12
c_out = 2
seq_len = 20
t = torch.rand(bs, nf, seq_len)
test_eq(max_pool_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))
bs = 16
nf = 12
c_out = 2
seq_len = 20
t = torch.rand(bs, nf, seq_len)
test_eq(create_pool_plus_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))
test_eq(create_pool_plus_head(nf, c_out, concat_pool=True, fc_dropout=0.5)(t).shape, (bs, c_out))
create_pool_plus_head(nf, c_out, seq_len, fc_dropout=0.5)
bs = 16
nf = 12
c_out = 2
seq_len = 20
t = torch.rand(bs, nf, seq_len)
test_eq(create_conv_head(nf, c_out, seq_len)(t).shape, (bs, c_out))
test_eq(create_conv_head(nf, c_out, adaptive_size=50)(t).shape, (bs, c_out))
create_conv_head(nf, c_out, 50)
bs = 16
nf = 12
c_out = 2
seq_len = 20
t = torch.rand(bs, nf, seq_len)
test_eq(create_mlp_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))
t = torch.rand(bs, nf, seq_len)
create_mlp_head(nf, c_out, seq_len, bn=True, fc_dropout=.5)
bs = 16
nf = 12
c_out = 2
seq_len = 20
t = torch.rand(bs, nf, seq_len)
test_eq(create_fc_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))
create_fc_head(nf, c_out, seq_len, bn=True, fc_dropout=.5)
bs = 16
nf = 12
c_out = 2
seq_len = 20
t = torch.rand(bs, nf, seq_len)
test_eq(create_rnn_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))
create_rnn_head(nf, c_out, seq_len, bn=True, fc_dropout=.5)
bs = 16
nf = 12
ni = 2
seq_len = 20
t = torch.rand(bs, nf, seq_len)
head = imputation_head(nf, ni, seq_len=None, ks=1, y_range=None, fc_dropout=0.)
test_eq(head(t).shape, (bs, ni, seq_len))
head = imputation_head(nf, ni, seq_len=None, ks=1, y_range=(.3,.7), fc_dropout=0.)
test_ge(head(t).min(), .3)
test_le(head(t).max(), .7)
y_range = (tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.2000, 0.2000, 0.2000, 0.2000, 0.3000,
                   0.3000, 0.3000, 0.3000]),
           tensor([0.6000, 0.6000, 0.6000, 0.6000, 0.7000, 0.7000, 0.7000, 0.7000, 0.8000,
                   0.8000, 0.8000, 0.8000]))
head = imputation_head(nf, ni, seq_len=None, ks=1, y_range=y_range, fc_dropout=0.)
test_ge(head(t).min(), .1)
test_le(head(t).max(), .9)
head
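The y_range argument behaves like fastai's SigmoidRange: the head's raw output is squashed with a sigmoid and rescaled into [low, high], where low/high may be scalars or per-channel tensors that broadcast as in the example above. A minimal sketch of that mapping:
import torch
def sigmoid_range_sketch(x, low, high):
    return torch.sigmoid(x) * (high - low) + low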
bs = 16
nf = 32
c = 5
seq_len = 10
d = 2
targ = torch.randint(0, c, (bs,d))
t = torch.randn(bs, nf, seq_len)
head = conv_lin_nd_head(nf, c, seq_len, d, conv_first=True, fc_dropout=.5)
inp = head(t)
test_eq(inp.shape, (bs, d, c))
loss = CrossEntropyLossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 5
seq_len = 10
d = [2, 8]
targ = torch.randint(0, c, [bs]+d)
t = torch.randn(bs, nf, seq_len)
head = conv_lin_nd_head(nf, c, seq_len, d, conv_first=False, fc_dropout=.5)
inp = head(t)
test_eq(inp.shape, [bs]+d+[c])
loss = CrossEntropyLossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 1
seq_len = 10
d = 2
targ = torch.rand(bs, d)
t = torch.randn(bs, nf, seq_len)
head = conv_lin_nd_head(nf, c, seq_len, d, conv_first=False, fc_dropout=.5)
inp = head(t)
test_eq(inp.shape, (bs, d))
loss = L1LossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 1
seq_len = 10
d = [2,3]
targ = torch.rand(bs, *d)
t = torch.randn(bs, nf, seq_len)
head = conv_lin_nd_head(nf, c, seq_len, d, conv_first=False, fc_dropout=.5)
inp = head(t)
test_eq(inp.shape, [bs]+d)
loss = L1LossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 5
seq_len = 10
d = 2
targ = torch.randint(0, c, (bs,d))
t = torch.randn(bs, nf, seq_len)
head = lin_nd_head(nf, c, seq_len, d, fc_dropout=.5)
inp = head(t)
test_eq(inp.shape, (bs, d, c))
loss = CrossEntropyLossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 5
seq_len = 10
d = [2, 8]
targ = torch.randint(0, c, [bs]+d)
t = torch.randn(bs, nf, seq_len)
head = lin_nd_head(nf, c, seq_len, d, fc_dropout=.5)
inp = head(t)
test_eq(inp.shape, [bs]+d+[c])
loss = CrossEntropyLossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 1
seq_len = 10
d = 2
targ = torch.rand(bs, d)
t = torch.randn(bs, nf, seq_len)
head = lin_nd_head(nf, c, seq_len, d, fc_dropout=.5)
inp = head(t)
test_eq(inp.shape, (bs, d))
loss = L1LossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 1
seq_len = 10
d = [2,3]
targ = torch.rand(bs, *d)
t = torch.randn(bs, nf, seq_len)
head = lin_nd_head(nf, c, seq_len, d, fc_dropout=.5)
inp = head(t)
test_eq(inp.shape, [bs]+d)
loss = L1LossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 5
seq_len = 10
d = 10
targ = torch.randint(0, c, (bs,d))
t = torch.randn(bs, nf, seq_len)
head = conv_3d_head(nf, c, seq_len, d)
inp = head(t)
test_eq(inp.shape, (bs, d, c))
loss = CrossEntropyLossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 1
seq_len = 10
d = 10
targ = torch.rand(bs, d)
t = torch.randn(bs, nf, seq_len)
head = conv_3d_head(nf, c, seq_len, d)
inp = head(t)
test_eq(inp.shape, (bs, d))
loss = L1LossFlat()(inp, targ)
loss, head
bs, c_in, seq_len = 16, 128, 50
c_out = 14
t = torch.rand(bs, c_in, seq_len)
uph = universal_pool_head(c_in, c_out, seq_len)
test_eq(uph(t).shape, (bs, c_out))
uph = universal_pool_head(c_in, c_out, seq_len, 2)
test_eq(uph(t).shape, (bs, c_out))
bs, c_in, seq_len = 16, 128, 50
c_out = 14
d = 5
t = torch.rand(bs, c_in, seq_len)
for head in heads:
    print(head.__name__)
    if head.__name__ == "create_conv_3d_head":
        h = head(c_in, c_out, seq_len, seq_len)
        test_eq(h(t).shape, (bs, seq_len, c_out))
    elif 'nd' in head.__name__:
        h = head(c_in, c_out, seq_len, d)
        test_eq(h(t).shape, (bs, d, c_out))
    else:
        h = head(c_in, c_out, seq_len)
        test_eq(h(t).shape, (bs, c_out))
bs = 2
ni = 32
sl = 4
t = torch.rand(bs, ni, sl)
test_eq(SqueezeExciteBlock(ni)(t).shape, (bs, ni, sl))
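SqueezeExciteBlock follows the Squeeze-and-Excitation pattern (Hu et al., 2018): squeeze with a global average pool, pass through a bottleneck MLP with a sigmoid gate, then rescale each channel. A minimal sketch of that pattern, not the tsai implementation:
import torch, torch.nn as nn
class SESketch(nn.Module):
    def __init__(self, ni, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(ni, ni // reduction), nn.ReLU(),
                                nn.Linear(ni // reduction, ni), nn.Sigmoid())
    def forward(self, x):                          # x: (bs, ni, seq_len)
        w = self.fc(x.mean(dim=-1))                # squeeze + excite -> per-channel weights
        return x * w.unsqueeze(-1)                 # channel-wise rescaling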
t = torch.ones(2,3,4)
test_ne(GaussianNoise()(t), t)
test_eq(GaussianNoise()(t).shape, t.shape)
t = torch.ones(2,3)
test_ne(GaussianNoise()(t), t)
test_eq(GaussianNoise()(t).shape, t.shape)
t = torch.ones(2)
test_ne(GaussianNoise()(t), t)
test_eq(GaussianNoise()(t).shape, t.shape)
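GaussianNoise-style regularization adds zero-mean noise to the input, usually only while training; a minimal sketch under that assumption:
import torch
def gaussian_noise_sketch(x, sigma=0.1, training=True):
    return x + sigma * torch.randn_like(x) if training else x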
model_output = torch.rand(16, 3)
targets = torch.randint(0, 2, (16,))
criterion = gambler_loss(2)
criterion(model_output, targets)
output = torch.rand(16, 2)
target = torch.randint(0, 2, (16,))
CrossEntropyLossOneHot(output, target)
from tsai.data.transforms import OneHot
output = nn.Parameter(torch.rand(16, 2))
target = torch.randint(0, 2, (16,))
one_hot_target = OneHot()(target)
CrossEntropyLossOneHot(output, one_hot_target)
# two illustrative random samples (a and b are assumed inputs for the t-test)
a = torch.randn(100)
b = torch.randn(100) + .5
ttest_tensor(a, b)
for _ in range(100):
    output = torch.rand(256, 2)
    target = torch.randint(0, 2, (256,))
    test_close(ttest_bin_loss(output, target).item(),
               ttest_ind(nn.Softmax(dim=-1)(output)[:, 1][target == 0],
                         nn.Softmax(dim=-1)(output)[:, 1][target == 1], equal_var=False)[0], eps=1e-3)
c_in = 10
x = torch.rand(64, c_in).to(device=default_device())
x = F.softmax(x, dim=1)
label = x.max(dim=1).indices
CenterLoss(c_in).to(x.device)(x, label), CenterPlusLoss(LabelSmoothingCrossEntropyFlat(), c_in).to(x.device)(x, label)
CenterPlusLoss(LabelSmoothingCrossEntropyFlat(), c_in)
c_in = 10
x = torch.rand(64, c_in).to(device=default_device())
x = F.softmax(x, dim=1)
label = x.max(dim=1).indices
FocalLoss(c_in).to(x.device)(x, label)
c_in = 10
output = torch.rand(64).to(device=default_device())
target = torch.rand(64).to(device=default_device())
TweedieLoss().to(output.device)(output, target)
t = torch.randn(2,3,10)
m = PositionwiseFeedForward(10, dropout=0., act='reglu', mlp_ratio=1)
test_eq(m(t).shape, t.shape)
B = 16
C = 10
M = 1500 # seq_len
n_heads = 1
D = 128 # model dimension
N = 512 # max_seq_len - latent's index dimension
d_k = D // n_heads
xb = torch.randn(B, C, M)
xb = (xb - xb.mean()) / xb.std()
# Attention
# input (Q)
lin = nn.Linear(M, N, bias=False)
Q = lin(xb).transpose(1,2)
test_eq(Q.shape, (B, N, C))
# q
to_q = nn.Linear(C, D, bias=False)
q = to_q(Q)
q = nn.LayerNorm(D)(q)
# k, v
context = xb.transpose(1,2)
to_kv = nn.Linear(C, D * 2, bias=False)
k, v = to_kv(context).chunk(2, dim = -1)
k = k.transpose(-1, -2)
k = nn.LayerNorm(M)(k)
v = nn.LayerNorm(D)(v)
test_eq(q.shape, (B, N, D))
test_eq(k.shape, (B, D, M))
test_eq(v.shape, (B, M, D))
output, attn, scores = ScaledDotProductAttention(res_attention=True)(q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1))
test_eq(output.shape, (B, 1, N, D))
test_eq(attn.shape, (B, 1, N, M))
test_eq(scores.shape, (B, 1, N, M))
scores.mean(), scores.std()
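The block above wires what resembles a Perceiver-style cross-attention: a latent query of length N attends over the full sequence of length M. At its core is standard scaled dot-product attention (Vaswani et al., 2017); a minimal sketch without masking, dropout, or the residual-attention scores:
import torch, torch.nn.functional as F
def sdp_attention_sketch(q, k, v):
    # q: (..., N, d_k), k: (..., d_k, M), v: (..., M, d_v)
    scores = torch.matmul(q, k) / (q.shape[-1] ** 0.5)
    attn = F.softmax(scores, dim=-1)
    return torch.matmul(attn, v), attn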
# class MultiheadAttention(Module):
#     def __init__(self, d_model:int, n_heads:int, d_k:Optional[int]=None, d_v:Optional[int]=None, res_attention:bool=False,
#                  dropout:float=0., qkv_bias:bool=True):
#         """Multi Head Attention Layer
#         Input shape:
#             Q:    [batch_size (bs) x max_q_len x d_model]
#             K, V: [batch_size (bs) x q_len x d_model]
#             mask: [q_len x q_len]
#         """
#         d_k = ifnone(d_k, d_model // n_heads)
#         d_v = ifnone(d_v, d_model // n_heads)
#         self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v
#         self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
#         self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
#         self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)
#         # Scaled Dot-Product Attention (multiple heads)
#         self.res_attention = res_attention
#         self.sdp_attn = ScaledDotProductAttention(res_attention=self.res_attention)
#         # Project output
#         project_out = not (n_heads == 1 and d_model == d_k)
#         self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(dropout)) if project_out else nn.Identity()
#     def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None,
#                 key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
#         bs = Q.size(0)
#         if K is None: K = Q
#         if V is None: V = Q
#         # Linear projections (+ split into multiple heads)
#         q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2)    # q_s: [bs x n_heads x max_q_len x d_k]
#         k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1)  # k_s: [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3)
#         v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2)    # v_s: [bs x n_heads x q_len x d_v]
#         # Apply Scaled Dot-Product Attention (multiple heads)
#         if self.res_attention:
#             output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
#         else:
#             output, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
#         # output: [bs x n_heads x q_len x d_v], attn: [bs x n_heads x q_len x q_len], scores: [bs x n_heads x max_q_len x q_len]
#         # back to the original input dimensions
#         output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v)  # output: [bs x q_len x n_heads * d_v]
#         output = self.to_out(output)
#         if self.res_attention: return output, attn_weights, attn_scores
#         else: return output, attn_weights
q = torch.rand([16, 3, 50, 8])
k = torch.rand([16, 3, 50, 8]).transpose(-1, -2)
v = torch.rand([16, 3, 50, 6])
attn_mask = torch.triu(torch.ones(50, 50)) # shape: q_len x q_len
key_padding_mask = torch.zeros(16, 50)
key_padding_mask[[1, 3, 6, 15], -10:] = 1
key_padding_mask = key_padding_mask.bool()
print('attn_mask', attn_mask.shape, 'key_padding_mask', key_padding_mask.shape)
output, attn = ScaledDotProductAttention(attn_dropout=.1)(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
output.shape, attn.shape
t = torch.rand(16, 50, 128)
output, attn = MultiheadAttention(d_model=128, n_heads=3, d_k=8, d_v=6)(t, t, t, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
output.shape, attn.shape
t = torch.rand(16, 50, 128)
att_mask = (torch.rand((50, 50)) > .85).float()
att_mask[att_mask == 1] = -np.inf
mha = MultiheadAttention(d_model=128, n_heads=3, d_k=8, d_v=6)
output, attn = mha(t, t, t, attn_mask=att_mask)
test_eq(torch.isnan(output).sum().item(), 0)
test_eq(torch.isnan(attn).sum().item(), 0)
loss = output[:2, :].sum()
test_eq(torch.isnan(loss).sum().item(), 0)
loss.backward()
for n, p in mha.named_parameters(): test_eq(torch.isnan(p.grad).sum().item(), 0)
t = torch.rand(16, 50, 128)
attn_mask = (torch.rand((50, 50)) > .85)
# True values will be masked
mha = MultiheadAttention(d_model=128, n_heads=3, d_k=8, d_v=6)
output, attn = mha(t, t, t, attn_mask=attn_mask)
test_eq(torch.isnan(output).sum().item(), 0)
test_eq(torch.isnan(attn).sum().item(), 0)
loss = output[:2, :].sum()
test_eq(torch.isnan(loss).sum().item(), 0)
loss.backward()
for n, p in mha.named_parameters(): test_eq(torch.isnan(p.grad).sum().item(), 0)
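The two cells above exercise the two common attn_mask conventions: an additive float mask (with -inf at the masked positions, added to the scores) and a boolean mask (True means masked). A sketch of how such masks are usually applied to the raw attention scores before the softmax:
import torch
scores = torch.rand(2, 3, 50, 50)                                    # raw attention scores
bool_mask = torch.rand(50, 50) > .85                                  # True = position to mask
masked_bool = scores.masked_fill(bool_mask, float('-inf'))            # boolean mask
add_mask = torch.zeros(50, 50).masked_fill(bool_mask, float('-inf'))
masked_add = scores + add_mask                                        # additive float mask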
t = torch.rand(16, 6, 37)
test_eq(MultiConv1d(6, None, kss=[1,3,5], keep_original=True)(t).shape, [16, 24, 37])
test_eq(MultiConv1d(6, 36, kss=[1,3,5], keep_original=False)(t).shape, [16, 36, 37])
test_eq(MultiConv1d(6, None, kss=[1,3,5], keep_original=True, dim=-1)(t).shape, [16, 6, 37*4])
test_eq(MultiConv1d(6, 60, kss=[1,3,5], keep_original=True)(t).shape, [16, 60, 37])
test_eq(MultiConv1d(6, 60, kss=[1,3,5], separable=True)(t).shape, [16, 60, 37])
t = ([1], [2], [3])
test_eq(LSTMOutput()(t), [1])
a = alphabet[np.random.randint(0,3,40)]
b = ALPHABET[np.random.randint(6,10,40)]
c = np.random.rand(40).reshape(4,1,10)
map_a = {k:v for v,k in enumerate(np.unique(a))}
map_b = {k:v for v,k in enumerate(np.unique(b))}
n_embeds = [len(m.keys()) for m in [map_a, map_b]]
szs = [emb_sz_rule(n) for n in n_embeds]
a = np.asarray(a.map(map_a)).reshape(4,1,10)
b = np.asarray(b.map(map_b)).reshape(4,1,10)
inp = torch.from_numpy(np.concatenate((c,a,b), 1)).float()
memb = MultiEmbedding(3, n_embeds, cat_pos=[1,2])
# registered buffers are part of the state_dict() but not module.parameters()
assert all([(k in memb.state_dict().keys()) for k in ['cat_pos', 'cont_pos']])
embeddings = memb(inp)
print(n_embeds, szs, inp.shape, embeddings.shape)
test_eq(embeddings.shape, (inp.shape[0],sum(szs)+1,inp.shape[-1]))
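emb_sz_rule follows fastai's embedding-size heuristic, roughly min(600, round(1.6 * n_cat**0.56)); a sketch of that rule for reference:
def emb_sz_rule_sketch(n_cat):
    # fastai-style heuristic: embedding width grows sublinearly with the number of categories
    return min(600, round(1.6 * n_cat ** 0.56))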