Here are a few questions related to Transformer masks in PyTorch.
My model will need to be auto-regressive. I understand that this is possible. In the original paper and here, it is well discussed that "The normal Transformer decoder is autoregressive at inference time and non-autoregressive at training time."
My model will have a very specific type of generation pattern, which does not really exist in the literature: during training, the whole input and output sequences are known. But during generation, I will use each output (a choice of direction) to prepare the next input point (a small patch of image at this coordinate), starting from a random patch. This is similar to the image below, but the inputs would be some function of y: x(ŷ_{<t}). This was easy to do using RNNs. I am not sure how to translate this to Transformers, using the description of masks in the doc or in this answer.
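To make the pattern more concrete, here is a rough sketch of the generation loop I have in mind; `model` and `extract_patch` are placeholder names for my own functions, not an existing API:

```python
import torch

# Rough sketch of the generation pattern (placeholder names):
# `model` maps a sequence of patches to a sequence of direction predictions,
# `extract_patch` crops a small image patch at a given coordinate.
def generate(model, extract_patch, image, start_coord, n_steps):
    coord = start_coord
    patches = [extract_patch(image, coord)]   # x_1: patch at a random starting coordinate
    directions = []
    for t in range(n_steps):
        x = torch.stack(patches)              # all inputs known so far
        y_hat = model(x)[-1]                  # direction predicted at this step
        directions.append(y_hat)
        coord = coord + y_hat                 # the output decides where to look next
        patches.append(extract_patch(image, coord))  # next input is a function of ŷ_{<t}
    return torch.stack(directions)
```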
- I'm a bit lost concerning the masks:
- At each iteration of the generation process, I only know one more input point than I have output points. My guess is that I could maybe use
src_mask = torch.triu([...], diagonal=1)
?
- The target given as input to the decoder is usually shifted. Here too, for the first src (a patch), the input to the decoder should be some type of "start of sequence" token. However, I work with continuous values (real-valued coordinates), not with dictionaries of words, as we always see in every example. My guess is that instead I could maybe use
tgt_mask = torch.triu([...], diagonal=0)
, meaning that at the first step, the mask will simply be False?
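To make those two guesses concrete, here is a tiny example of the two masks for a sequence of length 4, written as float masks the way nn.Transformer accepts them (0 = attend, -inf = masked). I am not sure this is the right way to use them, which is exactly my question:

```python
import torch

S = 4  # sequence length at some step of generation

# Strictly causal mask (my src_mask guess): position i can attend to positions <= i.
# I believe this matches what torch.nn.Transformer.generate_square_subsequent_mask(S) builds.
src_mask = torch.triu(torch.full((S, S), float('-inf')), diagonal=1)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])

# My tgt_mask guess with diagonal=0: the diagonal itself is also masked, so at the
# very first step (S = 1) the single entry is masked and the target sees nothing yet.
tgt_mask = torch.triu(torch.full((S, S), float('-inf')), diagonal=0)
# tensor([[-inf, -inf, -inf, -inf],
#         [  0., -inf, -inf, -inf],
#         [  0.,   0., -inf, -inf],
#         [  0.,   0.,   0., -inf]])
```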
- Considering my specific needs, I will prepare my own loop to infer each point one after the other. I know that there are examples on the web, but they run the whole sequence only to get the last point. Memory is a challenge for me. Is there a way that I can use torch's model to output only the last point at each iteration? I saw something about it in Hugging Face, where they return the hidden states, but there is no example of how to use them. Also, so far I have been using
torch.nn.modules.transformer
, with which I am a little more familiar. Is it possible using torch's RNN modules?
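For reference, here is roughly what my current inference loop looks like (all names and dimensions are placeholders); it re-runs the whole growing sequence through the model at every step and throws away everything but the last position, which is where my memory problem comes from:

```python
import torch
import torch.nn as nn

# Placeholder dimensions; in my real code the src entries are embedded image patches
# and the tgt entries are embedded direction choices.
d_model, n_steps = 64, 100
model = nn.Transformer(d_model=d_model)

src = torch.randn(1, 1, d_model)   # (S, N, E): the first (random) patch
tgt = torch.zeros(1, 1, d_model)   # my stand-in for a "start of sequence" vector

with torch.no_grad():
    for t in range(n_steps):
        tgt_mask = model.generate_square_subsequent_mask(tgt.size(0))
        out = model(src, tgt, tgt_mask=tgt_mask)   # recomputes every position each step...
        y_hat = out[-1]                            # ...although I only need the last point
        next_patch = torch.randn(1, 1, d_model)    # placeholder for "patch at the new coordinate"
        src = torch.cat([src, next_patch], dim=0)
        tgt = torch.cat([tgt, y_hat.unsqueeze(0)], dim=0)
```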
Thank you!