Recurrent Neural Networks: RNN, LSTM, and GRU

Introduction
In this article, we will explore the fascinating world of Recurrent Neural Networks (RNNs), a fundamental technology in artificial intelligence and machine learning. RNNs are unique due to their ability to process and analyze sequences of data, making them invaluable tools in fields ranging from speech recognition to time series analysis.
The essence of RNNs lies in their capacity to maintain a kind of ‘memory’ about previous inputs. This sets them apart from traditional neural networks, which process each input independently, without considering the order or sequence of the data. In this context, RNNs can be seen as machines that have a continuous understanding of context, something crucial for tasks such as natural language processing or predicting trends based on historical data.
Throughout this text, we will dive into the specifics of RNNs, starting with an explanation of how sequential data is fundamental to their operation. We will discuss the challenges and strategies for training these networks, address innovations in RNN architectures such as Long Short-Term Memory (LSTM) and Gated Recurrent Units (GRU), and illustrate how these technologies are implemented and used in real-world applications.
Prepare yourself for a detailed and informative journey, ideal for both AI enthusiasts and professionals in the field, aiming to unravel the mysteries and capabilities of Recurrent Neural Networks.
Sequential Data
Sequential Data is a fundamental concept in the field of data science, and a good starting point for those who are beginning to explore this area. Imagine a long line of people, where each person represents a data point, and this line is constantly moving. Each person (or data point) has a relationship with the person in front and behind them, creating a continuous sequence. This is the essence of Sequential Data: a series of information collected over time, where each part is significantly connected to the next.
To better understand these data, scientists use something called Training Data Sets. Think of this as choosing specific individuals from the line to study more closely. For example, a scientist might want to observe how temperature changes every hour of the day. By selecting just these specific points, they can focus on important patterns and trends, without being overwhelmed by the massive amount of data.
Sequential Data is not limited to a single field; it is used in various areas. In linguistics, for example, each word in a sentence follows another in a logical order, forming a coherent text. In medicine, a patient’s journey from hospital admission to discharge is a sequence of interconnected events. This interconnection is vital to understand the data as a whole, instead of seeing each data point in isolation.
A crucial part of working with Sequential Data is understanding Temporal Dependence. This means that what happened in the past in a sequence can affect what happens in the future. For example, medical decisions are often based on a patient’s health history, not made randomly. This relationship between past and future helps make sense of large amounts of data, allowing for more accurate predictions and insights.
Finally, a practical example of the importance of Sequential Data can be found in stock price forecasting using autoregressive models. These models analyze the history of stock prices to try to predict what will happen next. It’s like trying to guess the next number in a sequence, based on the previous numbers. This concept shows how understanding and analyzing Sequential Data is essential in various areas, from science to finance, providing a solid foundation for those who are beginning to explore the world of data science.
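As a minimal sketch of the autoregressive idea, the snippet below fits the simplest possible model, $ x_t \approx a \cdot x_{t-1} $, by least squares and uses it to guess the next value. The function names and the toy series are illustrative assumptions, not real stock prices, and real forecasting models use more lags and an intercept.

```python
def fit_ar1(series):
    """Least-squares estimate of a in x_t ≈ a * x_{t-1} (no intercept)."""
    pairs = list(zip(series[:-1], series[1:]))
    num = sum(prev * cur for prev, cur in pairs)
    den = sum(prev * prev for prev, _ in pairs)
    return num / den


def predict_next(series, a):
    """Predict the next value from the last observed one."""
    return a * series[-1]


# Toy series (hypothetical data): each value is twice the previous one
prices = [1.0, 2.0, 4.0, 8.0, 16.0]
a = fit_ar1(prices)
forecast = predict_next(prices, a)
print("estimated coefficient:", a)
print("forecast for the next step:", forecast)
```

The prediction depends only on earlier values of the same sequence, which is exactly the temporal dependence described above.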
Neural Networks Without Hidden States
The term ‘without hidden states’ refers to a neural network that does not maintain any information about previous inputs beyond the current one. In other words, it does not have memory of the data that has been processed before. I’ll explain this in more detail:
In the context of neural networks, a ‘hidden state’ usually refers to the network’s ability to retain some kind of internal state or memory that carries information across sequences of data. This is common in Recurrent Neural Networks (RNNs), where hidden states transmit information from one processing step to the next, allowing the network to make use of the context or sequential order of the data.
On the other hand, a neural network like the MLP (Multi-Layer Perceptron) has no hidden state: it may have hidden layers, but nothing is carried over from one input to the next. It treats each input independently, without taking into account the order or sequence. Each input (for example, an image or a set of features) is processed in isolation. This type of network is suitable for tasks where the order of the input data is not important.
To exemplify:
- With Hidden States: Imagine a neural network that is trying to understand a sentence. If it has hidden states, like an RNN, it can remember the previous word while processing the current word, which is very useful for understanding the full meaning of the sentence.
- Without Hidden States: Now, imagine you provide this network with a photo to classify whether it’s a cat or a dog. The network doesn’t need to remember other photos it saw before; it only needs to analyze the current photo. This is how an MLP without hidden states operates.
Therefore, ‘without hidden states’ means that the neural network does not have the capability to remember previous inputs, and each decision is made based only on the current input.
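The difference can be sketched with two toy models. Both are hypothetical, with coefficients chosen only for illustration: the first maps each input to an output on its own, while the second keeps a single number of state between calls.

```python
def stateless_model(x):
    # Output depends only on the current input
    return 2.0 * x + 1.0


class StatefulModel:
    """Toy model that carries one number of memory between inputs."""

    def __init__(self):
        self.h = 0.0  # hidden state, starts empty

    def step(self, x):
        # New state mixes the previous state with the current input
        self.h = 0.5 * self.h + x
        return self.h


def run_stateful(sequence):
    model = StatefulModel()
    out = None
    for x in sequence:
        out = model.step(x)
    return out


# Two sequences with different histories but the same last input
seq_a = [1.0, 2.0, 3.0]
seq_b = [0.0, 0.0, 3.0]

stateless_a = stateless_model(seq_a[-1])
stateless_b = stateless_model(seq_b[-1])
stateful_a = run_stateful(seq_a)
stateful_b = run_stateful(seq_b)

print(stateless_a, stateless_b)  # identical: history is ignored
print(stateful_a, stateful_b)    # different: history matters
```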
Recurrent Neural Networks
Recurrent Neural Networks (RNNs) have the ability to process sequences of data, such as sentences or time series. Unlike traditional neural networks, like the MLP, RNNs possess a kind of short-term memory, storing information from previous steps in the sequence. This memory is realized through “hidden states,” which allow the network to make connections over time.
The update equation of an RNN is a good way to understand how it works:
$$ H_t = \phi(X_t W_{xh} + H_{t-1} W_{hh} + b_h) $$
In this equation:
- $ H_t $ is the hidden state at the current moment, representing the network’s current memory.
- $ \phi $ is the activation function, like the hyperbolic tangent or ReLU, introducing non-linearity to the model.
- $ X_t $ is the input at the current time step.
- $ W_{xh} $ and $ W_{hh} $ are the weights that connect, respectively, the current input and the previous hidden state to the current hidden state.
- $ b_h $ is the bias for the hidden states.
The process of recurrence in RNNs involves using the previous hidden state $ H_{t-1} $ to generate the new state $ H_t $, along with the current input $ X_t $. This mechanism allows the RNN to maintain a contextual record over time, updating its “memory” based on the new input and the information from the previous step.
Furthermore, because the same weights $ W_{xh} $, $ W_{hh} $, and $ b_h $ are reused at every time step, the number of parameters stays constant regardless of sequence length, meaning the network can handle long sequences without increasing the model’s size.
The main advantage of RNNs over neural networks without hidden states is their ability to work with sequences where order and context are important. They are particularly useful for tasks like speech recognition, language translation, and time series forecasting. The ability to store and process sequential information makes RNNs powerful tools for many applications in artificial intelligence.
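To make the recurrence concrete, here is a minimal scalar sketch of the update equation, using only the Python standard library. The weight values are arbitrary assumptions chosen for illustration; the point is that the same three parameters serve a sequence of length 2 and a sequence of length 10.

```python
import math


def rnn_forward(xs, w_xh, w_hh, b_h, h0=0.0):
    """Apply H_t = tanh(X_t * w_xh + H_{t-1} * w_hh + b_h) over a sequence."""
    h = h0
    states = []
    for x in xs:
        h = math.tanh(x * w_xh + h * w_hh + b_h)
        states.append(h)
    return states


# Hypothetical parameter values, shared by every time step
params = dict(w_xh=0.5, w_hh=0.8, b_h=0.0)

short = rnn_forward([1.0, 0.0], **params)
long = rnn_forward([1.0] + [0.0] * 9, **params)

print(len(short), len(long))  # one hidden state per time step
print(short)
print(long[:2])               # same inputs so far, same hidden states
```

Even though the later inputs are all zero, the hidden state stays non-zero for a while: the first input keeps echoing through $ H_{t-1} $.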
Implementing Recurrent Neural Networks
Recurrent Neural Networks (RNNs) are fundamental in many machine learning applications, especially those involving sequential data like natural language processing or time series. An RNN is unique in its ability to maintain an internal state, capturing information about previously processed data, essential for tasks involving temporal dependencies.
Structure of the CustomRNN Class
In PyTorch, a machine learning library for Python, a custom RNN can be implemented by extending the nn.Module class. The CustomRNN class, as defined in the code below, exemplifies a basic RNN implementation. It is initialized with input_size, which defines the dimension of the input data, and hidden_size, which is the dimension of the hidden state.
Initialization of Parameters
The custom RNN utilizes three main sets of parameters:
- Input-to-Hidden State Weights (self.Wxh): Connect each input to the hidden state layer.
- Hidden-to-Hidden State Weights (self.Whh): Allow the RNN to maintain information over time.
- Bias for the Hidden State (self.bh): Contributes to the model’s ability to fit the data.
Sequence Processing
During the forward pass (forward), the RNN processes an input sequence (X_sequence) using a series of calculations:
- Each element of the sequence is processed individually.
- The hidden state is updated at each time step based on the current input and the previous hidden state.
- The tanh function introduces non-linearity, allowing the network to learn complex patterns.
- The hidden states are stored and then concatenated to form the final output of the sequence.
import torch
import torch.nn as nn


class CustomRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(CustomRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Initializing weights and bias
        self.Wxh = nn.Parameter(torch.randn(input_size, hidden_size))
        self.Whh = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.bh = nn.Parameter(torch.randn(1, hidden_size))

    def forward(self, X_sequence, Ht):
        outputs = []
        # Processing the sequence
        for t in range(X_sequence.size(0)):
            Xt = X_sequence[t]
            Ht = torch.tanh(torch.mm(Xt, self.Wxh) +
                            torch.mm(Ht, self.Whh) + self.bh)
            # Adding an extra dimension for seq_len
            outputs.append(Ht.unsqueeze(0))
        # Concatenating the output for each time step into a single tensor
        outputs = torch.cat(outputs, dim=0)
        return outputs, Ht


# Parameters
seq_length = 3  # Sequence length
batch_size = 1
input_size = 5  # Input data size
hidden_size = 10  # Hidden state size

# Instantiating the custom RNN
custom_rnn = CustomRNN(input_size, hidden_size)

# Initializing input data for the complete sequence
X_sequence = torch.randn(seq_length, batch_size, input_size)
hidden_state = torch.zeros(batch_size, hidden_size)  # Initial hidden state

# Passing the sequence through the RNN
output, hidden_state = custom_rnn(X_sequence, hidden_state)

# Printing the output and hidden state
print("Output:", output)
print("Hidden state:", hidden_state)
Training Recurrent Neural Networks (RNNs): Strategies and Challenges
Recurrent Neural Networks (RNNs) are a type of artificial neural network specialized in processing sequences of data, such as time series or natural language. However, dealing with long sequences of data can be challenging for these networks. When an RNN tries to remember many previous data points, it can face issues due to the large number of mathematical operations involved. This not only demands a lot from the computer’s memory but can also lead to calculation errors, a phenomenon known as numerical instability. Imagine trying to remember a long shopping list without writing anything down – eventually, you might start to forget or mix up the items.
A critical problem in RNNs is the vanishing gradient, where the gradient values decrease exponentially during training, making it difficult to update the weights for the early elements in the sequence. Gradient explosion, on the other hand, occurs when the gradient values grow exponentially, potentially leading to numerical instabilities.
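A simplified sketch shows where the exponential behavior comes from. It assumes a scalar, linear recurrence $ h_t = w \cdot h_{t-1} $ with no activation function, so the gradient of the last state with respect to the first is just $ w $ multiplied by itself once per time step. A real RNN adds the derivative of the activation at each step, but the repeated product is the same mechanism. The weight values are illustrative assumptions.

```python
def gradient_through_time(w_hh, steps):
    """d h_T / d h_0 for the linear recurrence h_t = w_hh * h_{t-1}."""
    grad = 1.0
    for _ in range(steps):
        grad *= w_hh  # one factor per time step (chain rule)
    return grad


steps = 50
vanishing = gradient_through_time(0.5, steps)  # |w| < 1: shrinks toward zero
exploding = gradient_through_time(1.5, steps)  # |w| > 1: grows without bound

print("w = 0.5:", vanishing)
print("w = 1.5:", exploding)
```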
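To circumvent these challenges, researchers have developed advanced architectures like LSTM and GRU networks, which implement gate mechanisms to control the flow of information, allowing the network to retain relevant information for longer and discard unnecessary information. Additionally, techniques such as gradient clipping are widely used: when the norm of the gradient exceeds a chosen threshold, the gradient is rescaled down to that threshold, which limits the damage caused by exploding gradients. Another common technique is truncated backpropagation through time, which limits how many time steps the gradient is propagated back through. This approach is similar to an RNN deciding to focus only on the most recent parts of the sequence, ignoring old data. This helps keep calculations manageable and minimizes errors, without losing much accuracy in predictions.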
There are different ways to compute the gradient across time steps. One is complete calculation, which considers the contribution of each time step in updating the weights, but it is a slow and complicated process. Other techniques include time-step truncation and random truncation. Time-step truncation limits the analysis to the most recent events, while random truncation uses a random variable to decide when to stop considering previous time steps.
Although random truncation may seem more accurate theoretically, in practice, it is not necessarily more effective than regular truncation. Often, looking just a few steps back is enough to capture the most important dependencies. Moreover, the extra variance introduced by the randomness can offset the benefit of a more accurate gradient estimate. Interestingly, models focusing on a shorter time interval can be more effective, as this can act as a form of regularization, helping the network generalize better to new data.
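Gradient clipping, mentioned earlier in this section, can be sketched in a few lines. This is a minimal version of clipping by global norm on a plain list of numbers; the function name and threshold are assumptions for illustration. In PyTorch the same idea is available as torch.nn.utils.clip_grad_norm_.

```python
import math


def clip_by_norm(grads, max_norm):
    """Rescale grads so that their L2 norm does not exceed max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)  # small enough: leave unchanged
    scale = max_norm / norm
    return [g * scale for g in grads]


large = [3.0, 4.0]  # L2 norm is 5
small = [0.3, 0.4]  # L2 norm is 0.5

clipped = clip_by_norm(large, max_norm=1.0)
unchanged = clip_by_norm(small, max_norm=1.0)

print(clipped)    # same direction as [3, 4], norm reduced to 1
print(unchanged)  # already within the threshold
```

Clipping changes the length of the gradient but not its direction, so the update still points the same way, only with a bounded step.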
LSTM (Long Short-Term Memory)
LSTM, which stands for “Long Short-Term Memory,” is a recurrent neural network architecture widely used in Artificial Intelligence. It was created to overcome a specific challenge found in standard Recurrent Neural Networks (RNNs).
RNNs are excellent for processing sequences of data, such as a series of numbers or words in a sentence. However, they face a significant problem known as “vanishing and exploding gradients.” This means that when trying to learn from very long sequences, RNNs either lose important information or the information becomes excessively amplified. Imagine trying to recall a long story: the longer the story, the harder it is to remember the details from the beginning, or you might end up focusing too much on a specific part.
LSTM was developed to solve this problem. It is an improved version of RNNs, equipped with a special structure called “memory cell.” These memory cells are like mini-information storages, capable of retaining important information for a long period and discarding unnecessary information. This allows the LSTM to maintain a balance, avoiding the loss or excessive amplification of information when processing sequences of data.
The “long and short-term memory” in the name of LSTM comes from its ability to manage different types of memory. Standard RNNs have long-term memory, stored in their weights (adjustments in the model), and short-term memory, in the activations (temporary information passed between nodes). LSTM, with its memory cells, adds an intermediate level of storage, helping the model to maintain relevant information and discard irrelevant ones during data processing.
Input Gate, Forget Gate, and Output Gate
LSTM Neural Networks (Long Short-Term Memory) process sequences of data, like texts or time series, using a system of gates that decide what to remember and what to forget. These gates operate in a coordinated manner to manage the cell’s memory in three main aspects:
- Input Gate: This gate determines which new information is relevant and should be stored in the memory. The associated equation, $ I_t = \sigma(X_t W_{xi} + H_{t-1} W_{hi} + b_i) $, uses the sigmoid function $ \sigma $ to evaluate the importance of new information, transforming the values into a range between 0 and 1. Values closer to 1 indicate high relevance.
- Forget Gate: Responsible for evaluating which old information is no longer useful and should be discarded. The equation $ F_t = \sigma(X_t W_{xf} + H_{t-1} W_{hf} + b_f) $ also uses the sigmoid function to determine which data should be retained (values close to 1) or discarded (values close to 0).
- Output Gate: This gate decides which memory information will be used in calculating the network’s output at that moment. The equation $ O_t = \sigma(X_t W_{xo} + H_{t-1} W_{ho} + b_o) $ determines which stored information is relevant for representing the current state.
Each of these gates considers the current input $ X_t $ (for example, a new word in a sentence), the previous state $ H_{t-1} $ (what was processed before), and performs mathematical operations (like multiplications and additions) with weights and biases adjusted during the network’s training.
The weights and biases $ W $ and $ b $ in each equation are adjusted during the training of the model to optimize performance. They determine the importance of the inputs and the previous state in each decision. The sigmoid functions in the equations help model binary decisions, facilitating the choice between retaining or discarding information, as sigmoid functions return values between 0 (zero) and 1 (one), where values closer to zero should be forgotten, and those closer to one should be remembered.
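The three gate equations share one shape: a weighted sum of the current input and the previous hidden state, passed through a sigmoid. The scalar sketch below illustrates this with hand-picked weights. All numeric values are assumptions for illustration, since in a real network they are learned during training.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gate(x_t, h_prev, w_x, w_h, b):
    """Generic gate: sigma(X_t * W_x + H_{t-1} * W_h + b), scalar version."""
    return sigmoid(x_t * w_x + h_prev * w_h + b)


x_t, h_prev = 1.0, 0.5  # current input and previous hidden state

# Hypothetical weights for each gate
i_t = gate(x_t, h_prev, w_x=2.0, w_h=1.0, b=0.0)    # sigmoid(2.5): mostly open
f_t = gate(x_t, h_prev, w_x=-2.0, w_h=-1.0, b=0.0)  # sigmoid(-2.5): mostly closed
o_t = gate(x_t, h_prev, w_x=0.0, w_h=0.0, b=0.0)    # sigmoid(0): half open

print("input gate: ", i_t)
print("forget gate:", f_t)
print("output gate:", o_t)
```

Whatever the weights are, every gate value lands strictly between 0 and 1, which is what lets it act as a soft switch.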
Input Node
The input node, or $ \tilde{C}_t $, in an LSTM architecture is where new candidate values for the cell state are generated. The equation:
$$ \tilde{C}_t = \tanh(X_t W_{xc} + H_{t-1} W_{hc} + b_c) $$
is central to this function. Here, $ \tanh $ is the hyperbolic tangent activation function that helps normalize the inputs between -1 and 1, enabling the model to capture non-linear relationships. $ W_{xc} $ and $ W_{hc} $ are weight matrices that transform, respectively, the current input $ X_t $ and the previous hidden state $ H_{t-1} $. The bias vector $ b_c $ allows for fine-tuning in the transformation.
This node is essential for several reasons:
- Generation of New Candidates: It combines current information and previous learnings to propose updates to the cell state, allowing the LSTM to consider new information while maintaining previous knowledge.
- Capturing Complex Relationships: Through the $ \tanh $ function, the input node can process and prepare complex information, which is essential for subsequent decisions about which data to keep or discard.
- Selective Updating: The values generated here are weighted by the input gate, meaning that only information considered useful is retained, preserving the LSTM’s ability to maintain memory over sequences of data for extended periods.
Therefore, the input node is a crucial component that ensures the LSTM’s ability to update its memory in a selective and controlled way, which is vital for tasks that require understanding and processing sequences of data, such as language translation or time series forecasting.
Internal State of the Memory Cell
The internal state of a memory cell in an LSTM network is a way for the neural network to store information over time. Each memory cell in an LSTM has the capability to maintain a record of what happened in previous time steps. Think of it as a memory that the network carries from one time step to the next.
Here are the key points to understand the internal state of a memory cell:
- Information Storage: The internal state holds important information that the neural network needs to remember. This information is used for making predictions or decisions in subsequent steps.
- Persistence: Unlike traditional neural networks, where information flows in only one direction and is lost after each step, the internal state of the LSTM can retain information for a long period. This is crucial for tasks such as natural language processing, where understanding the past context (previous words or sentences) is necessary to interpret the present.
- Selective Updating: The internal state can be selectively updated. This means that at each time step, the network decides based on the new input and the previous hidden state what should be remembered (updated) and what should be forgotten.
- Influence on Output: The internal state, along with the hidden state, influences the final output of the LSTM. For example, in a text prediction task, the internal state can help determine the next word based on the previous words.
The equation for the internal state of a memory cell in an LSTM is given by:
$$ C_t = F_t \odot C_{t-1} + I_t \odot \tilde{C}_t $$
Here is what each term in the equation represents and how it works simply:
- $ C_t $: This is the new internal state of the memory cell at the current time $ t $. It is the result of combining what the cell decides to keep from the previous memory and what it decides to add from the newly proposed information.
- $ F_t $: The forget gate decides how much of the previous memory $ C_{t-1} $ will be retained. If $ F_t $ is 0, it means forgetting everything from the previous memory; if it is 1, it means keeping everything.
- $ \odot $: This is the symbol for the Hadamard product, i.e., element-wise multiplication. So, $ F_t \odot C_{t-1} $ means that each element of the previous cell state is multiplied by the corresponding element of the forget gate.
- $ C_{t-1} $: The internal state of the memory cell at the previous time $ t-1 $. It contains the information that was stored up to that point.
- $ I_t $: The input gate decides how much of the new candidate information $ \tilde{C}_t $ will be added to the memory cell.
- $ \tilde{C}_t $: This is the new candidate cell state proposed, based on the current information and the previous hidden state.
- $ I_t \odot \tilde{C}_t $: Similar to the Hadamard product with the forget gate, here the new candidate information is multiplied by the input gate. This determines how much of the new information will be added to the cell state.
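A small worked example of the cell state equation may help. The gate and state values below are made up for illustration, chosen so that each of the three positions shows a different behavior: overwrite, keep, and blend.

```python
def hadamard(a, b):
    """Element-wise (Hadamard) product of two equal-length lists."""
    return [x * y for x, y in zip(a, b)]


def update_cell_state(f_t, c_prev, i_t, c_tilde):
    """C_t = F_t ⊙ C_{t-1} + I_t ⊙ C~_t, element by element."""
    kept = hadamard(f_t, c_prev)    # what survives from the old memory
    added = hadamard(i_t, c_tilde)  # what is written from the candidate
    return [k + a for k, a in zip(kept, added)]


# Hypothetical values for a cell state with three elements
c_prev = [2.0, -1.0, 4.0]   # previous cell state C_{t-1}
c_tilde = [0.5, 0.9, -1.0]  # candidate values from the input node
f_t = [0.0, 1.0, 0.5]       # forget gate
i_t = [1.0, 0.0, 0.5]       # input gate

c_t = update_cell_state(f_t, c_prev, i_t, c_tilde)
print(c_t)
# position 0: old value forgotten, candidate written in full
# position 1: old value kept, candidate ignored
# position 2: half of the old value plus half of the candidate
```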
Why It Works
- Flexibility: This equation allows the LSTM to dynamically adjust how much it wants to remember old information and how much it wants to add new information. This gives the network the ability to retain relevant information over time and discard unnecessary information.
- Long-Term Memory: This memory cell structure is what allows LSTMs to be used in tasks requiring long-term memory, such as natural language processing or time series, because they can remember information over many time steps.
Hidden State
The hidden state carries the information that is passed forward in the neural network. The cell state $C_t $ contains long-term memory, while the hidden state $H_t $ is the short-term memory that also serves as the output of the LSTM at a particular time step.
The function of the output gate $O_t $ is like a regulator controlling when and how much of the long-term memory in the cell state is transferred to the hidden state. If the output gate is open (values close to 1), the memory flows to the hidden state and affects the network’s output. If it is closed (values close to 0), the memory is retained within the cell and does not immediately affect the output.
This mechanism allows the LSTM to decide what is important to pass on to the next layers or time steps, enabling the retention of relevant information and the suppression of unnecessary information, which is crucial for processing sequences of data and tasks that rely on temporal context.
The equation for the hidden state is:
$$ H_t = O_t \odot \tanh(C_t) $$
Where:
- $H_t $: The hidden state at time $ t $, which is what the LSTM passes to the next layer or the next time step.
- $O_t $: The output gate at time $ t $. It decides how much of the internal state is revealed to the hidden state. When $O_t $ is close to 1, it allows the information to flow out freely. When it is close to 0, it blocks the flow of information, keeping it within the cell.
- $\odot $: Represents element-wise multiplication, known as the Hadamard product.
- $\tanh(C_t) $: The hyperbolic tangent function applied to the cell state $C_t $ at time $ t $. It squashes the values of the cell state into the range between -1 and 1.
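The hidden state equation can be checked with a few numbers. In this sketch the cell state and output gate values are hypothetical, picked so that the gate is fully open, fully closed, and half open at the three positions.

```python
import math


def hidden_state(o_t, c_t):
    """H_t = O_t ⊙ tanh(C_t), element by element."""
    return [o * math.tanh(c) for o, c in zip(o_t, c_t)]


# Hypothetical values for a three-element state
c_t = [0.5, -1.0, 1.5]  # cell state
o_t = [1.0, 0.0, 0.5]   # output gate

h_t = hidden_state(o_t, c_t)
print(h_t)
# position 0: gate open, tanh(0.5) passes through unchanged
# position 1: gate closed, the memory stays inside the cell
# position 2: gate half open, half of tanh(1.5) is exposed
```

Because $ \tanh $ is bounded and the gate is at most 1, every element of the hidden state stays between -1 and 1 regardless of how large the cell state grows.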
Implementing an LSTM (Long Short-Term Memory)
The code below defines a custom LSTM (Long Short-Term Memory) cell using the PyTorch library, one of the most popular deep learning frameworks. The LSTM cell is a fundamental unit of an LSTM network, widely used in time series problems and natural language processing due to its ability to capture long-term dependencies.
import torch
import torch.nn as nn


class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Parameters for the input gate
        self.W_ii = nn.Parameter(torch.empty(hidden_size, input_size))
        self.W_hi = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_ii = nn.Parameter(torch.empty(hidden_size))
        self.b_hi = nn.Parameter(torch.empty(hidden_size))
        # Parameters for the forget gate
        self.W_if = nn.Parameter(torch.empty(hidden_size, input_size))
        self.W_hf = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_if = nn.Parameter(torch.empty(hidden_size))
        self.b_hf = nn.Parameter(torch.empty(hidden_size))
        # Parameters for the input node
        self.W_ig = nn.Parameter(torch.empty(hidden_size, input_size))
        self.W_hg = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_ig = nn.Parameter(torch.empty(hidden_size))
        self.b_hg = nn.Parameter(torch.empty(hidden_size))
        # Parameters for the output gate
        self.W_io = nn.Parameter(torch.empty(hidden_size, input_size))
        self.W_ho = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_io = nn.Parameter(torch.empty(hidden_size))
        self.b_ho = nn.Parameter(torch.empty(hidden_size))
        # Parameter initialization (fills the uninitialized tensors above)
        self.init_parameters()

    def init_parameters(self):
        # Xavier initialization for weight matrices, zeros for bias vectors
        for p in self.parameters():
            if p.data.ndimension() >= 2:
                nn.init.xavier_uniform_(p.data)
            else:
                nn.init.zeros_(p.data)

    def forward(self, inputs, init_states=None):
        """
        Forward method to process a sequence of inputs over time.
        :param inputs: Sequence of input tensors, shape (seq_len, batch, input_size).
        :param init_states: Initial states (hidden state and cell state).
        :return: Outputs over time and the last state (hidden state and cell state).
        """
        if init_states is None:
            h_t = torch.zeros(
                (inputs.shape[1], self.hidden_size), device=inputs.device)
            c_t = torch.zeros(
                (inputs.shape[1], self.hidden_size), device=inputs.device)
        else:
            h_t, c_t = init_states
        outputs = []
        # Iterates through the sequence of inputs
        for x in inputs:
            # Input gate: decides which information will be updated
            i_t = torch.sigmoid(x @ self.W_ii.t() +
                                self.b_ii + h_t @ self.W_hi.t() + self.b_hi)
            # Forget gate: decides which information will be discarded from the cell state
            f_t = torch.sigmoid(x @ self.W_if.t() +
                                self.b_if + h_t @ self.W_hf.t() + self.b_hf)
            # Input node: creates a vector of new candidates to be added to the cell state
            g_t = torch.tanh(x @ self.W_ig.t() + self.b_ig +
                             h_t @ self.W_hg.t() + self.b_hg)
            # Updating the internal state of the memory cell
            c_t = f_t * c_t + i_t * g_t
            # Output gate: decides which parts of the cell state will be outputs
            o_t = torch.sigmoid(x @ self.W_io.t() +
                                self.b_io + h_t @ self.W_ho.t() + self.b_ho)
            # Hidden state: is the output of the LSTM, using the cell state passed through a tanh function
            h_t = o_t * torch.tanh(c_t)
            outputs.append(h_t)
        return outputs, (h_t, c_t)


# Input size and hidden size for demonstration
input_size = 10
hidden_size = 20

# Instance of the LSTM cell
lstm_cell = LSTMCell(input_size, hidden_size)

# Example of input sequence and initial states
inputs = torch.randn(5, 1, input_size)  # Sequence of inputs
init_states = (torch.zeros(1, hidden_size), torch.zeros(
    1, hidden_size))  # Null initial states

# Outputs over time and the last states
outputs, (h_n, c_n) = lstm_cell(inputs, init_states)
print(outputs)
The LSTMCell class inherits from nn.Module, which is the base of all neural network modules in PyTorch. The __init__ constructor initializes the LSTM cell’s parameters, which are:
- Parameters for the input gate: responsible for deciding which information will be updated in the cell state.
- Parameters for the forget gate: determine which information from the cell state will be discarded, helping to avoid the vanishing gradient problem.
- Parameters for the input node: create a vector of new candidates that can be added to the cell state.
- Parameters for the output gate: select the parts of the cell state that will be used to compute the final hidden state.
The parameters are weight matrices and bias vectors that will be learned during the network’s training. The init_parameters function is used to initialize these parameters with appropriate values, crucial for the neural network’s good performance.
The forward method is the heart of the LSTM cell, where the logic for processing a sequence of inputs over time is implemented. If no initial state is provided, the hidden state and cell state are initialized with zeros. Then, for each element in the input sequence, the input, forget, and output gates are calculated using sigmoid activation functions, and the input node is calculated with the hyperbolic tangent activation function. These values are used to update the cell state and compute the new hidden state. The resulting hidden state is collected in an output list.
At the end of the code, an instance of the LSTM cell is created with specified input and hidden sizes. A sequence of random inputs is passed to the LSTM cell along with null initial states, and the LSTM cell processes the sequence, returning the outputs over time and the final states.
GRU (Gated Recurrent Units)
LSTMs can be computationally intensive, meaning they require considerable processing power and time to learn. Because of this, researchers began looking for simpler and more efficient alternatives. One such alternative is the GRU.
The GRU is a variation of the LSTM that is also capable of capturing long-term dependencies in data, but in a more efficient manner. It does this using a more streamlined structure that merges the roles of the LSTM’s forget and input gates into a single update gate, reducing the number of mathematical operations needed and, consequently, the computation time.
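A rough parameter count makes the saving concrete. The sketch below assumes one input weight matrix, one hidden weight matrix, and a single bias vector per block, with four blocks for an LSTM (three gates plus the input node) and three for a GRU (two gates plus the candidate hidden state). The sizes are hypothetical, and implementations that keep two bias vectors per block, as the LSTMCell above does, will count slightly more.

```python
def block_params(input_size, hidden_size):
    """Parameters in one gate-like block: W_x, W_h, and one bias vector."""
    return input_size * hidden_size + hidden_size * hidden_size + hidden_size


def lstm_params(input_size, hidden_size):
    # Input gate, forget gate, output gate, input node
    return 4 * block_params(input_size, hidden_size)


def gru_params(input_size, hidden_size):
    # Reset gate, update gate, candidate hidden state
    return 3 * block_params(input_size, hidden_size)


input_size, hidden_size = 10, 20  # illustrative sizes
n_lstm = lstm_params(input_size, hidden_size)
n_gru = gru_params(input_size, hidden_size)

print("LSTM parameters:", n_lstm)
print("GRU parameters: ", n_gru)
print("GRU / LSTM ratio:", n_gru / n_lstm)
```

Under these assumptions the GRU needs three quarters of the parameters of an LSTM with the same sizes.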
Update Gate and Reset Gate
GRU (Gated Recurrent Unit) networks, like LSTM (Long Short-Term Memory) networks, stand out in processing sequences, such as in natural language processing and time series analysis. In a GRU, the Update Gate and Reset Gate are the essential mechanisms that control the flow of information.
Understanding the Reset Gate:
The Reset Gate helps determine the amount of past information that will be used to calculate the current state.
When the GRU processes a sequence of data, it maintains a hidden state that carries information across different points in time. This hidden state is crucial for the network to make predictions or decisions based not only on the current input but also on what it has previously learned.
The functioning of the Reset Gate can be understood as follows:
- If the Reset Gate is close to 1 (i.e., when the sigmoid activation function returns a high value), it means that the previous hidden state $ H_{t-1} $ is considered important, and therefore, much of the information it contains should be “remembered” or retained, strongly influencing the next hidden state.
- If the Reset Gate is close to 0 (i.e., the sigmoid activation function returns a low value), it indicates that the previous state is not as relevant to the current moment in the sequence. In this case, the network “forgets” or discards part of what it knew before, allowing the new hidden state to be formed with less influence from the past.
Thus, the Reset Gate acts as a regulator that can dynamically forget unnecessary past information or preserve important aspects to contribute to future decisions or predictions, based on the relevance of past information to the current task. The equation for the Reset Gate is:
$$ R_t = \sigma(X_t W_{xr} + H_{t-1} W_{hr} + b_r) $$
Where $ R_t $ is the vector of the Reset Gate at time $ t $, $ \sigma $ is the sigmoid function, $ X_t $ is the input vector, $ W_{xr} $ is the weight matrix from the input to the Reset Gate, $ H_{t-1} $ is the previous hidden state, $ W_{hr} $ is the weight matrix from the hidden state to the Reset Gate, and $ b_r $ is the bias. This gate allows the neural network to efficiently decide what to forget, especially when significant changes occur in the data.
Exploring the Update Gate:
The Update Gate controls how much of the previous state we might still want to remember.
The Update Gate operates in the following way:
- When the value of the Update Gate is close to 1, it means the network decides to retain most of the previous hidden state $ H_{t-1} $. This is useful in situations where the past information is relevant for predicting or understanding the next step in the sequence. For example, if the sequence is a sentence where the next term strongly depends on the context provided by the previous terms, the Update Gate will allow this context to be maintained.
- On the other hand, if the value of the Update Gate is close to 0, the network is deciding to update its hidden state with new information, placing less emphasis on what was learned previously. This is beneficial when the new input data contains sufficient information for the current prediction or decision, and the previous history is less important.
The Update Gate, therefore, plays a critical role in determining how past information is merged with new inputs to form the current state. This allows the GRU to be adaptable to the context of the data sequence, maintaining the flexibility to retain critical information or adapt to new patterns as they emerge. The equation for the Update Gate is:
$$ Z_t = \sigma(X_t W_{xz} + H_{t-1} W_{hz} + b_z) $$
Here, $ Z_t $ is the vector of the Update Gate, $ X_t $ is the input vector, $ W_{xz} $ and $ W_{hz} $ are the weight matrices for the input and the previous hidden state, respectively, and $ b_z $ is the bias. This gate plays a crucial role in preserving important information over time, allowing for continuity and coherence in the data sequence.
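The shapes involved in this equation are easier to see in a batched form. The sketch below, assuming PyTorch is installed, computes $ Z_t $ for a small batch; the sizes and the randomly drawn weights are placeholders for illustration, not trained parameters.

```python
import torch

torch.manual_seed(0)  # make the random placeholders reproducible

batch_size, input_size, num_hiddens = 2, 3, 4  # illustrative sizes

X_t = torch.randn(batch_size, input_size)      # current input X_t
H_prev = torch.randn(batch_size, num_hiddens)  # previous hidden state H_{t-1}

# Random stand-ins for the learned parameters of the update gate.
W_xz = torch.randn(input_size, num_hiddens)
W_hz = torch.randn(num_hiddens, num_hiddens)
b_z = torch.zeros(num_hiddens)

# Z_t = sigmoid(X_t W_xz + H_{t-1} W_hz + b_z); b_z is broadcast over the batch.
Z_t = torch.sigmoid(X_t @ W_xz + H_prev @ W_hz + b_z)

print(Z_t.shape)  # one gate value per hidden unit, for each batch element
```

Each hidden unit gets its own gate value, so the network can keep some components of the previous state while overwriting others.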
Candidate Hidden State
The candidate hidden state $ \tilde{H}_t $ is a provisional version of the new hidden state that the GRU is attempting to construct at a given moment. This state is calculated using the current input $ X_t $ and the previous hidden state $ H_{t-1} $, each multiplied by its own weight matrix ($ W_{xh} $ and $ W_{hh} $, respectively), plus a bias vector $ b_h $. The formula is:
$$ \tilde{H}_t = \tanh(X_t W_{xh} + (R_t \odot H_{t-1}) W_{hh} + b_h) $$
Here:
- $ \tanh $ is the hyperbolic tangent activation function, which transforms the input values into a new set of values between -1 and 1.
- $ \odot $ is the Hadamard product operator, meaning the multiplication is done element-wise.
- $ R_t $ is the vector resulting from the Reset Gate, which decides how much of the previous state $ H_{t-1} $ will be considered when calculating the new candidate for the hidden state.
Based on the influence of the Reset Gate, the candidate hidden state is created. Then, the Update Gate will decide how much of the previous hidden state will be retained and how much of the new candidate hidden state will be used to form the final hidden state $ H_t $.
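The effect of the Reset Gate on the candidate can be checked with a scalar sketch in plain Python. The function `candidate_state` and the numbers below are hypothetical values chosen for illustration only.

```python
import math


def candidate_state(x_t, h_prev, r_t, w_xh, w_hh, b_h):
    """Scalar version of H~_t = tanh(X_t W_xh + (R_t * H_{t-1}) W_hh + b_h)."""
    return math.tanh(x_t * w_xh + (r_t * h_prev) * w_hh + b_h)


# Illustrative (assumed) input, previous state, and weights.
x_t, h_prev = 0.5, -0.8
w_xh, w_hh, b_h = 1.0, 2.0, 0.0

# Reset gate fully open: the previous state contributes to the candidate.
with_history = candidate_state(x_t, h_prev, 1.0, w_xh, w_hh, b_h)

# Reset gate fully closed: the candidate depends on the current input only.
without_history = candidate_state(x_t, h_prev, 0.0, w_xh, w_hh, b_h)

print(f"R_t = 1: {with_history:.3f}")
print(f"R_t = 0: {without_history:.3f}")
```

With the gate closed, the term containing $ H_{t-1} $ vanishes and the candidate reduces to $ \tanh(X_t W_{xh} + b_h) $, which is what "forgetting" means in this formula.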
Hidden State
The update of the hidden state $ H_t $ in a GRU is carried out using the following equation:
$$ H_t = Z_t \odot H_{t-1} + (1 - Z_t) \odot \tilde{H}_t $$
- $ H_t $: Hidden state at time $ t $, which will be passed to the next time step and to the output layer if needed.
- $ H_{t-1} $: Hidden state at time $ t-1 $, i.e., the hidden state from the previous step.
- $ \tilde{H}_t $: Candidate hidden state at time $ t $, calculated based on the current input and the previous hidden state, after being modified by the reset gate.
- $ Z_t $: Update gate at time $ t $, which determines how much of the previous state will be retained.
- $ \odot $: Hadamard product operator, meaning element-wise multiplication.
This equation determines how the new hidden state $ H_t $ is calculated:
- Update Gate $ Z_t $: Determines the proportion of the previous state $ H_{t-1} $ that should be retained. The closer it is to 1, the more of the previous state is kept.
- Previous Hidden State $ H_{t-1} $: Is weighted by the update gate $ Z_t $ to decide how much of this previous state should be kept.
- Candidate Hidden State $ \tilde{H}_t $: Is a new proposed hidden state generated based on the current input and the previous state $ H_{t-1} $, adjusted by the reset gate. It is weighted by $ 1 - Z_t $, indicating that if the update gate favors the previous state, the impact of the candidate state will be lesser, and vice versa.
- Final Combination: The updated hidden state $ H_t $ is the combination of the previous state weighted by the update gate with the candidate state weighted by the complement of the update gate, $ 1 - Z_t $. This allows the GRU to decide whether to maintain old information (based on relevance determined by the update gate) or replace it with a new state proposal.
This mechanism enables the GRU to adapt to different time-dependency requirements, retaining important information from previous steps of the sequence or updating it with new information as needed.
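The update rule is an element-wise interpolation between the previous state and the candidate, which a few lines of plain Python can illustrate. The helper `gru_update` and the vectors below are assumptions made up for this example.

```python
def gru_update(z_t, h_prev, h_tilde):
    """Element-wise H_t = Z_t * H_{t-1} + (1 - Z_t) * H~_t on plain lists."""
    return [z * hp + (1.0 - z) * hc for z, hp, hc in zip(z_t, h_prev, h_tilde)]


# Illustrative (assumed) values: three hidden units with the same previous
# state and candidate, but three different update-gate values.
h_prev = [0.9, 0.9, 0.9]
h_tilde = [-0.5, -0.5, -0.5]
z_t = [1.0, 0.5, 0.0]

h_t = gru_update(z_t, h_prev, h_tilde)
print(h_t)  # approximately [0.9, 0.2, -0.5]
```

The first unit copies the previous state unchanged, the last unit adopts the candidate entirely, and the middle unit lands halfway between the two.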
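Implementing a GRU (Gated Recurrent Unit)
The following code snippet is a custom implementation of a GRU using PyTorch. Although the class is named GRUCell, its forward method loops over every time step of an input with shape (sequence_length, batch_size, input_size) and returns the stacked hidden states together with the final hidden state.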
import torch
from torch import nn


class GRUCell(nn.Module):
    def __init__(self, input_size, num_hiddens):
        super().__init__()
        self.num_hiddens = num_hiddens
        # Parameters for the update gate
        self.W_xz = nn.Parameter(torch.randn(input_size, num_hiddens))
        self.W_hz = nn.Parameter(torch.randn(num_hiddens, num_hiddens))
        self.b_z = nn.Parameter(torch.zeros(num_hiddens))
        # Parameters for the reset gate
        self.W_xr = nn.Parameter(torch.randn(input_size, num_hiddens))
        self.W_hr = nn.Parameter(torch.randn(num_hiddens, num_hiddens))
        self.b_r = nn.Parameter(torch.zeros(num_hiddens))
        # Parameters for the candidate hidden state
        self.W_xh = nn.Parameter(torch.randn(input_size, num_hiddens))
        self.W_hh = nn.Parameter(torch.randn(num_hiddens, num_hiddens))
        self.b_h = nn.Parameter(torch.zeros(num_hiddens))
        # Parameter initialization (overwrites the random values above)
        self.init_parameters()

    def init_parameters(self):
        # A simple initialization procedure:
        # Xavier uniform for weight matrices, zeros for bias vectors
        for p in self.parameters():
            if p.data.ndimension() >= 2:
                nn.init.xavier_uniform_(p.data)
            else:
                nn.init.zeros_(p.data)

    def forward(self, inputs, H=None):
        # inputs has shape: (sequence_length, batch_size, input_size)
        if H is None:
            # Initial state with shape: (batch_size, num_hiddens)
            H = torch.zeros(
                (inputs.shape[1], self.num_hiddens), device=inputs.device)
        outputs = []
        for X in inputs:  # iterate over the time steps
            # Calculation of the update gate
            Z = torch.sigmoid(torch.matmul(X, self.W_xz) +
                              torch.matmul(H, self.W_hz) + self.b_z)
            # Calculation of the reset gate
            R = torch.sigmoid(torch.matmul(X, self.W_xr) +
                              torch.matmul(H, self.W_hr) + self.b_r)
            # Calculation of the candidate hidden state
            H_tilde = torch.tanh(torch.matmul(X, self.W_xh) +
                                 torch.matmul(R * H, self.W_hh) + self.b_h)
            # Updating the hidden state
            H = Z * H + (1 - Z) * H_tilde
            outputs.append(H)
        return torch.stack(outputs), H


# Example of usage:
# Define the input size and number of hidden units
input_size = 5  # Example input feature dimension
num_hiddens = 10  # Number of hidden units
# Create a GRU cell
gru_cell = GRUCell(input_size, num_hiddens)
# Create some dummy input data
# (sequence_length, batch_size, input_size)
inputs = torch.randn(3, 1, input_size)
# Forward pass through the GRU cell
outputs, H = gru_cell(inputs)
print(outputs.shape)  # torch.Size([3, 1, 10])
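For comparison, PyTorch also ships a built-in `nn.GRU` layer. The sketch below, which reuses the same illustrative sizes as the custom example, runs it on dummy data and inspects the shapes. Note that the built-in layer is not numerically identical to the custom class: according to the PyTorch documentation, it applies the reset gate after the hidden-to-hidden matrix product and uses separate bias terms, so the two are related formulations rather than drop-in equivalents.

```python
import torch
from torch import nn

torch.manual_seed(0)  # reproducible dummy data and initial weights

input_size, num_hiddens = 5, 10  # same illustrative sizes as above
seq_len, batch_size = 3, 1

# Built-in GRU layer; by default it expects (seq_len, batch_size, input_size).
gru = nn.GRU(input_size=input_size, hidden_size=num_hiddens)

inputs = torch.randn(seq_len, batch_size, input_size)
outputs, h_n = gru(inputs)

print(outputs.shape)  # torch.Size([3, 1, 10]): hidden state at every time step
print(h_n.shape)      # torch.Size([1, 1, 10]): final hidden state per layer
```

For a single-layer, unidirectional GRU, the last entry of `outputs` and `h_n[0]` hold the same values, mirroring the pair returned by the custom implementation.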
Conclusion
Throughout this article, we discussed the concepts and applications of Recurrent Neural Networks, including their advanced variations like LSTM and GRU. We observed how these technologies are crucial for understanding and processing sequential data, offering innovative solutions to complex challenges across various fields, from natural language processing to time series analysis.
RNNs, with their unique ability to maintain a “memory” of previous data, have revolutionized the way we handle data sequences. LSTMs and GRUs, with their gate mechanisms, took this capability even further, allowing for more refined control over what is remembered and what is forgotten, making the processing of long and complex sequences more efficient and effective.
The implementation of these networks, as demonstrated through the PyTorch code examples, illustrates the flexibility and adaptability of these technologies in practical applications. With the LSTM and GRU examples, we saw how it is possible to construct and train neural networks that can capture and utilize long-term dependencies in data, an essential ability for many current AI tasks.
In summary, Recurrent Neural Networks, LSTMs, and GRUs are powerful tools in the arsenal of Artificial Intelligence, offering the capability to process and interpret sequences of data effectively. Their continued use and evolution will likely open new horizons and possibilities in the field of machine learning and beyond, contributing significantly to the advancement of technology and its application in a variety of real-world domains.