# Source code for RADAR.time_series.algorithms.modelsTransformersTS.vanillaTransformer.model
import torch
import torch.nn as nn
import torch.nn.functional as F
from .encoder import Encoder
from .decoder import Decoder
import torch
import torch.nn as nn
class Transformer(nn.Module):
    """
    Transformer for anomaly detection in time series.

    Args:
        size_enc_in: input feature dimension of the encoder.
        size_dec_in: input/output feature dimension of the decoder.
        ulayers_feedfwd: number of units in the feed-forward layers.
        seq_len: sequence length, forwarded to both the Encoder and the Decoder.
        d_qk: dimension of the queries and keys (Q/K).
        d_v: dimension of the values (V).
        d_model: internal model dimension.
        n_layers: number of layers in both the encoder and the decoder.
        n_heads: number of attention heads.
        dropout_rate: dropout rate.
        embedding_scale: whether to scale embeddings by sqrt(d_model).
        attns_outs: if True, ``forward`` also returns the attention maps.
        use_padding_mask: keyword-only. If True (default, original behaviour),
            time steps whose features are all exactly zero are treated as
            padding and masked out of attention. Set to False for data in
            which all-zero time steps are genuine observations.
    """

    def __init__(self,
                 size_enc_in,
                 size_dec_in,
                 ulayers_feedfwd,
                 seq_len,
                 d_qk=64,
                 d_v=64,
                 d_model=512,
                 n_layers=6,
                 n_heads=8,
                 dropout_rate=0.1,
                 embedding_scale=False,
                 attns_outs=False,
                 *,
                 use_padding_mask=True):
        super().__init__()
        self.attns_outs = attns_outs
        self.use_padding_mask = use_padding_mask
        # Hyper-parameters shared by the encoder and the decoder; only the
        # input feature size differs between the two stacks.
        shared = dict(
            n_layers=n_layers,
            feedforward_units=ulayers_feedfwd,
            seq_len=seq_len,
            n_heads=n_heads,
            d_model=d_model,
            d_keys=d_qk,
            d_values=d_v,
            dropout_rate=dropout_rate,
            embedding_scale=embedding_scale,
        )
        self.encoder = Encoder(input_size=size_enc_in, **shared)
        self.decoder = Decoder(input_size=size_dec_in, **shared)
        # Projects decoder features (d_model) back to the decoder input size.
        self.linear = nn.Linear(d_model, size_dec_in, bias=False)

    def gen_mask(self, src, tgt):
        """
        Build the boolean attention masks for a time-series batch.

        Args:
            src: encoder input, shape [B, L_src, D].
            tgt: decoder input, shape [B, L_tgt, D].

        Returns:
            (src_mask, tgt_mask), where ``True`` marks a position that may be
            attended to:
            - src_mask: [B, 1, 1, L_src] padding mask over encoder positions.
            - tgt_mask: [B, 1, L_tgt, L_tgt], the decoder padding mask
              combined with a lower-triangular (no-peek) causal mask.
        """
        # A time step counts as padding when all of its features are exactly
        # zero. Genuine series can contain such steps, so the heuristic can be
        # switched off. getattr keeps instances pickled before the option
        # existed working with the original behaviour.
        if getattr(self, "use_padding_mask", True):
            src_valid = src.abs().sum(dim=-1) != 0  # [B, L_src]
            tgt_valid = tgt.abs().sum(dim=-1) != 0  # [B, L_tgt]
        else:
            src_valid = torch.ones(src.shape[:2], dtype=torch.bool, device=src.device)
            tgt_valid = torch.ones(tgt.shape[:2], dtype=torch.bool, device=tgt.device)
        src_mask = src_valid[:, None, None, :]  # [B, 1, 1, L_src]
        tgt_mask = tgt_valid[:, None, None, :]  # [B, 1, 1, L_tgt]

        # Causal mask: decoder position i may only attend to positions <= i.
        tgt_len = tgt.size(1)
        nopeak_mask = torch.tril(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device)
        )[None, None]  # [1, 1, L_tgt, L_tgt]

        # NOTE(review): with padding masking on, a query row can end up fully
        # masked (e.g. tgt[:, 0] all zeros, or an all-zero src window). Whether
        # that yields NaN or a uniform distribution depends on how the
        # attention layers apply the mask, which is not visible here.
        tgt_mask = tgt_mask & nopeak_mask  # [B, 1, L_tgt, L_tgt]
        return src_mask, tgt_mask

    def forward(self, enc_inputs, dec_inputs):
        """
        Run the encoder-decoder stack and project to the decoder feature size.

        Args:
            enc_inputs: encoder input, shape [B, L_enc, D_enc].
            dec_inputs: decoder input, shape [B, L_dec, D_dec].

        Returns:
            ``out``, whose last dimension is ``size_dec_in`` (presumably
            [B, L_dec, size_dec_in]). If ``attns_outs`` is True, returns the
            tuple ``(out, attns_enc, attns_dec, attns_enc_dec)`` instead.
        """
        mask_enc, mask_dec = self.gen_mask(enc_inputs, dec_inputs)
        enc_out, attns_enc = self.encoder(enc_inputs, mask_enc)
        # mask_enc is also handed to the decoder, presumably to mask encoder
        # positions in encoder-decoder attention.
        dec_out, attns_dec, attns_enc_dec = self.decoder(
            dec_inputs, enc_out, mask_dec, mask_enc
        )
        out = self.linear(dec_out)
        if self.attns_outs:
            return out, attns_enc, attns_dec, attns_enc_dec
        return out