model.fit() 是非常强大的工具。很多人以为自己自定义网络的时候就没法用它,但其实 tf.keras 提供了 tf.keras.layers.Layer 和 tf.keras.models.Model,通过继承(子类化)它们写出来的自定义网络同样可以用 fit。
但是,RNN 不能直接这么用:RNN 还需要维护 hidden state,如果直接照上面的方式去写,会报各种错,反正我在这上面吃了大亏……
为此,tf.keras 提供了可供继承的 RNN cell 类(我用的是 keras.layers.AbstractRNNCell;GRUCell、LSTMCell 也都可以用):继承它编写自己的 cell,再用 keras.layers.RNN 包装起来即可。
代码奉上
class myGRU(keras.layers.AbstractRNNCell):
    """Minimal GRU cell, intended to be wrapped by ``keras.layers.RNN``.

    For one time step, with ``x`` the step input and ``h`` the previous
    hidden state, the cell computes::

        r  = sigmoid(x @ W_r + h @ U_r + b_r)        # reset gate
        z  = sigmoid(x @ W_z + h @ U_z + b_z)        # update gate
        n  = tanh(x @ W_n + (r * h) @ U_n + b_n)     # candidate state
        h' = (1 - z) * h + z * n

    For each gate the input kernel ``W`` and the recurrent kernel ``U`` are
    stored stacked in a single weight of shape ``[input_dim + units, units]``;
    rows ``[:input_dim]`` are ``W`` and rows ``[input_dim:]`` are ``U``.

    Args:
        units: Size of the hidden state (and of the per-step output).
        **kwargs: Forwarded unchanged to ``keras.layers.AbstractRNNCell``.
    """

    def __init__(self, units, **kwargs):
        self.units = units
        super(myGRU, self).__init__(**kwargs)

    @property
    def state_size(self):
        """Size of the single hidden state carried between time steps."""
        return self.units

    def build(self, input_shape):
        """Create the stacked kernels and biases for the three gates.

        Args:
            input_shape: Shape of the per-step input; only its last entry
                (the feature dimension) is used.
        """
        self.dim = input_shape[-1]
        stacked_kernel_shape = [self.dim + self.units, self.units]
        bias_shape = [self.units]

        # Reset gate.
        self.w_r = self.add_weight(shape=stacked_kernel_shape, initializer='uniform', name='reset_gate', trainable=True)
        self.b_r = self.add_weight(shape=bias_shape, initializer='zeros', name='reset gate bias', trainable=True)
        # Update gate.
        self.w_z = self.add_weight(shape=stacked_kernel_shape, initializer='uniform', name='update_gate', trainable=True)
        self.b_z = self.add_weight(shape=bias_shape, initializer='zeros', name='update gate bias', trainable=True)
        # Candidate ("interim") state. The weight names are kept exactly as
        # they were so previously saved weights still match by name.
        self.w_n = self.add_weight(shape=stacked_kernel_shape, initializer='uniform', name='intetim', trainable=True)
        self.b_n = self.add_weight(shape=bias_shape, initializer='zeros', name='intetim gate bias', trainable=True)
        self.built = True

    def call(self, inputs, states):
        """Run one GRU step.

        Args:
            inputs: Input for the current time step; its last dimension must
                equal the ``input_dim`` seen in ``build``.
            states: Sequence whose first element is the previous hidden state
                (last dimension ``units``).

        Returns:
            A pair ``(output, new_state)``; both are the new hidden state.
        """
        prev_output = states[0]
        r = tf.nn.sigmoid(inputs @ self.w_r[:self.dim] + prev_output @ self.w_r[self.dim:] + self.b_r)
        z = tf.nn.sigmoid(inputs @ self.w_z[:self.dim] + prev_output @ self.w_z[self.dim:] + self.b_z)
        # The candidate state uses its own kernel (w_n), and the reset gate is
        # applied element-wise to the previous state, not to the kernel, so
        # the computation is valid for any batch size.
        n = tf.nn.tanh(inputs @ self.w_n[:self.dim] + (r * prev_output) @ self.w_n[self.dim:] + self.b_n)
        output = (1 - z) * prev_output + z * n
        return output, output
class GRUReg(keras.models.Model):
    """Regression model: a ``myGRU`` recurrent layer followed by a 1-unit Dense.

    Args:
        units: Hidden size passed to the ``myGRU`` cell.
        **kwargs: Forwarded unchanged to ``keras.models.Model``.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        # Wrap the custom cell in the generic RNN layer, which drives the
        # cell over the sequence and manages its state.
        recurrent_cell = myGRU(units)
        self.gru = keras.layers.RNN(recurrent_cell)
        # Single-unit head producing one value per input sequence.
        self.dense = keras.layers.Dense(1)

    def call(self, inputs):
        """Feed ``inputs`` through the recurrent layer, then the Dense head."""
        return self.dense(self.gru(inputs))
# Smoke test: build the model with 32 hidden units and train it with
# model.fit() on constant dummy data.
model = GRUReg(32)
model.compile(
    optimizer=keras.optimizers.RMSprop(),
    loss='mae',
    metrics=['mae'],
)
# Inputs are all ones with shape (3, 128, 2); targets are three zeros.
model.fit(
    x=tf.ones((3, 128, 2)),
    y=tf.zeros(3),
    batch_size=1,
    epochs=5,
)