Migrating an RNN Model to a Transformer
To process sequential data, where the position and order of the data points are important, a recurrent neural network (RNN) has traditionally been used. This post explores why an RNN model should be replaced by a transformer, and how to migrate an existing RNN model to a transformer.
Overview of RNNs and Their Weaknesses
An RNN is made up of neural network cells that have memory. These cells differ from traditional "feed forward" neural network cells because their output is fed back to their input. The output represents the cell's state, which is a function of both the current input and the memory of all previous inputs.
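The feedback loop described above can be sketched in a few lines of NumPy. This is a minimal Elman-style cell (the same update rule as Keras's SimpleRNN layer); the weight names, sizes, and random values are hypothetical and chosen only for illustration.

```python
import numpy as np

def rnn_forward(inputs, W_x, W_h, b):
    """Run a minimal (Elman-style) RNN over a sequence.

    inputs: shape (timesteps, input_dim)
    W_x:    input-to-state weights, shape (input_dim, state_dim)
    W_h:    state-to-state (feedback) weights, shape (state_dim, state_dim)
    b:      bias, shape (state_dim,)
    Returns the output state at every timestep, shape (timesteps, state_dim).
    """
    state = np.zeros(W_h.shape[0])  # no memory before the first timestep
    states = []
    for x_t in inputs:
        # The previous state is fed back in, so the new state depends on
        # the current input and, indirectly, on every earlier input.
        state = np.tanh(x_t @ W_x + state @ W_h + b)
        states.append(state)
    return np.stack(states)

# Hypothetical sizes and random weights, for illustration only.
rng = np.random.default_rng(0)
input_dim, state_dim, timesteps = 3, 4, 6
W_x = rng.normal(scale=0.5, size=(input_dim, state_dim))
W_h = rng.normal(scale=0.5, size=(state_dim, state_dim))
b = np.zeros(state_dim)
seq = rng.normal(size=(timesteps, input_dim))

states = rnn_forward(seq, W_x, W_h, b)
print(states.shape)  # (6, 4)
```

Changing only the first input changes the final state, which is the "memory" of the cell at work.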
The weakness of these memory cells is that information from the distant past becomes diluted or forgotten. As a result, RNN models struggle to make accurate predictions when the input sequence is long.
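A toy illustration of this dilution, assuming a deliberately simplified one-dimensional linear RNN, h_t = w * h_{t-1} + x_t, with a hypothetical feedback weight w = 0.5. Real RNNs use nonlinear activations and learned weight matrices, so this is not a model of any real network, but the same repeated shrinking is what makes distant inputs fade.

```python
def first_input_influence(length, w=0.5):
    """Final state of a toy 1-D linear RNN, h_t = w * h_{t-1} + x_t,
    fed a sequence whose first input is 1 and whose other inputs are 0.
    The result is exactly how much of the first input survives."""
    h = 0.0
    for t in range(length):
        x_t = 1.0 if t == 0 else 0.0
        h = w * h + x_t  # each step scales the old memory by w
    return h

for length in (2, 10, 50):
    print(length, first_input_influence(length))
```

The first input's contribution is multiplied by w at every step, so after n steps only w**(n - 1) of it remains: half after 2 steps, but effectively nothing after 50.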
The feedback mechanism of an RNN also has the side effect of long training times. Because the cell's state at each timestep depends on its state at the previous timestep, the states must be computed one after another. This is much slower than a feed forward neural network, which can process all timesteps in parallel.
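The difference can be seen in a short NumPy sketch (hypothetical sizes and random weights). A feed forward layer treats each timestep independently, so a single matrix multiplication covers the whole sequence; the recurrent version needs an explicit loop because step t cannot start until step t-1 is finished.

```python
import numpy as np

rng = np.random.default_rng(1)
timesteps, input_dim, hidden_dim = 5, 3, 4
seq = rng.normal(size=(timesteps, input_dim))
W = rng.normal(size=(input_dim, hidden_dim))
W_h = rng.normal(scale=0.5, size=(hidden_dim, hidden_dim))

# Feed forward: every timestep is independent, so one matrix multiply
# computes all of them at once (and a GPU can run them in parallel).
ff_all = np.tanh(seq @ W)

# The same result computed one timestep at a time, showing that the
# timesteps do not depend on each other.
ff_loop = np.stack([np.tanh(x_t @ W) for x_t in seq])

# Recurrent: step t needs the state from step t-1, so this loop cannot
# be collapsed into a single multiply over all timesteps.
state = np.zeros(hidden_dim)
rnn_states = []
for x_t in seq:
    state = np.tanh(x_t @ W + state @ W_h)
    rnn_states.append(state)
rnn_states = np.stack(rnn_states)
```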
Addressing RNN Weaknesses
An attention mechanism was developed to address the problem of data dilution. Instead of keeping only the final state, the output state at each timestep is saved. At the end of the sequence, each saved output state is scored, and the scores are fed to a softmax layer. This layer generates a probability distribution describing how much attention to place on each timestep. Using this attention information, the final prediction can be calculated from the weighted sum of the output states.
In summary, the attention mechanism determines how much attention should be placed at each timestep to generate the prediction.
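A minimal NumPy sketch of this attention step, assuming the saved output states are already available. Scoring each state with a single learned vector (`score_vector`, hypothetical here) is one simple choice; published variants such as additive and multiplicative attention compute the scores differently, and a real model would pass the resulting context vector to further layers to make the prediction.

```python
import numpy as np

def softmax(scores):
    # Subtracting the max avoids overflow and does not change the result.
    exp = np.exp(scores - np.max(scores))
    return exp / exp.sum()

def attention_pool(states, score_vector):
    """Attention over saved RNN output states.

    states:       shape (timesteps, state_dim), one output state per timestep
    score_vector: shape (state_dim,), a learned scoring vector (hypothetical)
    Returns (weights, context): the attention distribution over timesteps
    and the weighted sum of the output states.
    """
    scores = states @ score_vector   # one score per timestep
    weights = softmax(scores)        # probability distribution over timesteps
    context = weights @ states       # weighted sum of the output states
    return weights, context

rng = np.random.default_rng(2)
states = rng.normal(size=(6, 4))     # stand-in for 6 saved output states
score_vector = rng.normal(size=4)
weights, context = attention_pool(states, score_vector)
```

If every score is equal, the weights are uniform and the context is simply the average of the states; training the scoring vector is what lets the model focus on the timesteps that matter.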
However, the attention mechanism does not address the problem of long training times, because the underlying RNN still processes the timesteps sequentially.
Overview of Transformers
Instead of using neural network cells with memory, the transformer uses only feed forward cells, so there is no feedback loop. The position of each input in the sequence is preserved by adding a positional encoding to the input at each timestep. This data is then fed to an attention mechanism, which is conceptually similar to the attention mechanism used in an RNN.
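Both ideas can be sketched in NumPy: the sinusoidal positional encoding from the original transformer paper (sine on even dimensions, cosine on odd ones), and scaled dot-product self-attention, softmax(QK^T / sqrt(d_k))V. This is a simplified sketch: a real transformer first applies learned query, key, and value projections and uses multiple attention heads, which are omitted here, and the embeddings below are random stand-ins.

```python
import numpy as np

def positional_encoding(length, d_model):
    """Sinusoidal positional encoding:
    PE[pos, 2i]   = sin(pos / 10000**(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000**(2i / d_model))
    """
    positions = np.arange(length)[:, np.newaxis]   # (length, 1)
    dims = np.arange(d_model)[np.newaxis, :]       # (1, d_model)
    angle_rates = 1.0 / np.power(10000, (2 * (dims // 2)) / d_model)
    angles = positions * angle_rates               # (length, d_model)
    pe = np.zeros((length, d_model))
    pe[:, 0::2] = np.sin(angles[:, 0::2])          # even dimensions
    pe[:, 1::2] = np.cos(angles[:, 1::2])          # odd dimensions
    return pe

def scaled_dot_product_attention(q, k, v):
    """Return softmax(q k^T / sqrt(d_k)) v and the attention weights."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)                # (len_q, len_k)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(3)
length, d_model = 5, 8
embeddings = rng.normal(size=(length, d_model))    # stand-in input vectors
x = embeddings + positional_encoding(length, d_model)  # add position info

# Self-attention: queries, keys, and values all come from the same sequence,
# and every timestep is computed at once with matrix multiplications.
output, weights = scaled_dot_product_attention(x, x, x)
```

Because nothing in this computation loops over timesteps, the whole sequence is processed in parallel, which is the source of the transformer's training speed.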
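Because the transformer addresses both weaknesses of an RNN, it is not only much faster to train, since all timesteps are processed in parallel, but it also tends to produce more accurate results on long sequences, since attention connects every timestep directly to every other timestep. The only downside is that, at the time of writing, the transformer was not natively supported in Keras. TensorFlow 2.4 later added a tf.keras.layers.MultiHeadAttention layer, but a complete transformer model still has to be assembled from components.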
Migrating an RNN Model to a Transformer
At the time of writing, the only viable transformer model for Keras was the one in the TensorFlow transformer tutorial. The tutorial walks through each component, so each component can be modified as needed.
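The tutorial's custom optimizer (Adam with a warmup learning-rate schedule), custom loss function, and custom training function are optional. They can be replaced by the default Adam optimizer, the 'categorical_crossentropy' loss function (or 'sparse_categorical_crossentropy' if the targets are integer indices rather than one-hot vectors), and the Keras built-in Model.fit() method.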
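To show what the default optimizer replaces, here is a plain-Python sketch of the warmup learning-rate schedule from the original transformer paper, which the TensorFlow tutorial pairs with Adam. The function name `transformer_lr` and the model width are hypothetical; in Keras the same formula would be written as a subclass of `tf.keras.optimizers.schedules.LearningRateSchedule`.

```python
def transformer_lr(step, d_model, warmup_steps=4000):
    """Warmup-then-decay learning rate from "Attention Is All You Need".

    lr = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)

    The rate rises linearly for the first warmup_steps steps, peaks at
    step == warmup_steps, then decays in proportion to 1/sqrt(step).
    step starts at 1, because step**-0.5 is undefined at step 0.
    """
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# d_model=128 is a hypothetical model width chosen for illustration.
for step in (1, 1000, 4000, 16000):
    print(step, transformer_lr(step, d_model=128))
```

With the default optimizer, this schedule is replaced by Adam's constant default learning rate of 0.001, and the whole training setup reduces to `model.compile(optimizer='adam', loss='categorical_crossentropy')` followed by a call to `model.fit()`.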
The masks can be computed inside the transformer model from its inputs, instead of being passed into the transformer. In addition, for a model that predicts each timestep from the previous ones, only the look-ahead mask is mandatory; the padding masks are optional, depending on whether the input sequences are padded.
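Both masks can be derived from the input itself, which is what makes it possible to build them inside the model. The NumPy sketch below follows the convention used in the TensorFlow tutorial, where 1 marks a position that must not be attended to and the look-ahead and padding masks are combined with an element-wise maximum. The token IDs, and the assumption that padding uses ID 0, are hypothetical.

```python
import numpy as np

def look_ahead_mask(length):
    """1 marks a future position that timestep i must not attend to."""
    return np.triu(np.ones((length, length)), k=1)

def padding_mask(token_ids, pad_id=0):
    """1 marks a padding position (assumes padding uses pad_id)."""
    return (np.asarray(token_ids) == pad_id).astype(float)

def masked_softmax(scores, mask):
    # Adding a large negative number drives masked positions to ~0 weight.
    scores = scores + mask * -1e9
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

tokens = np.array([5, 7, 2, 0, 0])  # hypothetical sequence, two pad tokens
# Combine: a position is blocked if it is in the future OR is padding.
mask = np.maximum(look_ahead_mask(len(tokens)),
                  padding_mask(tokens)[np.newaxis, :])

# With equal raw scores, each timestep spreads its attention evenly over
# the earlier, non-padding timesteps only.
weights = masked_softmax(np.zeros((5, 5)), mask)
```

If the input sequences are never padded, the padding mask is all zeros and only the look-ahead mask has any effect.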
Thank you for reading. I hope you find this guide helpful for migrating a RNN model to a transformer.
Wayne Cheng is an A.I., machine learning, and deep learning developer at Audoir, LLC. His research involves the use of artificial neural networks to create music. Prior to starting Audoir, LLC, he worked as an engineer in various Silicon Valley startups. He has an M.S.E.E. degree from UC Davis, and a Music Technology degree from Foothill College.