网上关于 Keras 的 RNNLayer 中的输入写的很不清楚,整理如下:

LSTM 的输入

tf.keras.layers.LSTM()参数

文档

1
2
3
4
5
6
7
8
9
10
11
tf.keras.layers.LSTM(
units, activation='tanh', recurrent_activation='sigmoid',
use_bias=True, kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros', unit_forget_bias=True,
kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None,
activity_regularizer=None, kernel_constraint=None, recurrent_constraint=None,
bias_constraint=None, dropout=0.0, recurrent_dropout=0.0,
return_sequences=False, return_state=False, go_backwards=False, stateful=False,
time_major=False, unroll=False, **kwargs
)

input_dim、input_length、input_shape 的关系

LSTM 的输入是一个三维的张量(numpy narray), 三维张量的 shape 是[samples, time steps, features],也就是[样本数量,时间步长(序列数量),特征长度]。LSTM layer 的参数需要确定其中的两个,在 model.fit 时,就能够对 trainX 进行训练。因此 input_dim 表示单个样本的特征长度,可以用 trainX.shape[2]赋值; input_length 表示的就是时间步长,序列长度,可以用 trainX.shape[1]进行赋值。

另外一种写法是 input_shape,其实就是这两个量的结合:input_shape = (input_length, input_dim)

因此以下的两种写法是等价的:

1
2
model.add(LSTM(units=256, return_sequences=True,
input_dim=trainX.shape[2], input_length=trainX.shape[1]))
1
2
model.add(LSTM(units=256, return_sequences=True,
input_shape=(trainX.shape[1], trainX.shape[2])))

但比较奇怪的是这样设置最终的结果第一维会是 None,最终输出的是[None,timesteps, feature]。如果设置input_size=trainX.size的话,会出现以下错误:
ValueError: Input 0 of layer lstm is incompatible with the layer: expected ndim=3, found ndim=4.
但是如果使用 batch_input_shape=trainX.shape就可以正常运行,并且最终得到训练的每一个样本的 [timesteps,feature]张量。

后来查了keras LSTM的官方文档,它对input的定义是[batch, timesteps, feature],也就是说第一个参数指的是 batch 的大小,如果没有就默认为 None。在model.fit里有一个batch_size,如果设置了该batch_size的值,那么LSTM的层的input会自动根据trainX.shape[0]和batch_size的值来确定每一个输入的batch的大小。

batch size 限制了在可以执行权重更新之前向网络显示的样本数。拟合模型时使用的 batch size 控制一次必须进行多少预测。

GRU 的输入

tf.keras.layers.GRU()参数

1
2
3
4
5
6
7
8
9
10
11
tf.keras.layers.GRU(
units, activation='tanh', recurrent_activation='sigmoid',
use_bias=True, kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros', kernel_regularizer=None,
recurrent_regularizer=None, bias_regularizer=None, activity_regularizer=None,
kernel_constraint=None, recurrent_constraint=None, bias_constraint=None,
dropout=0.0, recurrent_dropout=0.0, return_sequences=False, return_state=False,
go_backwards=False, stateful=False, unroll=False, time_major=False,
reset_after=True, **kwargs
)