model.fit() 是非常强大的工具,但很多人以为自己自定义网络的时候就没法用了。其实 tf.keras 提供了 tf.keras.layers.Layer 和 tf.keras.models.Model,通过继承(子类化)的方式编写自定义网络,同样能用上 fit

但是,RNN 不能这么用,因为 RNN 还有 hidden state 要处理,直接这么写会报各种错,反正我是吃了大亏……

tf.keras 提供了 RNN cell 基类(我用的是 keras.layers.AbstractRNNCell,另外 GRUCell、LSTMCell 也都可以用)供继承编写

代码奉上(运行前需要先 `import tensorflow as tf` 和 `from tensorflow import keras`)

class myGRU(keras.layers.AbstractRNNCell):
    """Minimal GRU cell meant to be wrapped by ``keras.layers.RNN``.

    Each gate owns one kernel of shape ``(input_dim + units, units)``: the
    first ``input_dim`` rows multiply the step input and the remaining
    ``units`` rows multiply the previous hidden state.

    Args:
        units: Size of the hidden state, which is also the per-step output.
        **kwargs: Forwarded unchanged to ``keras.layers.AbstractRNNCell``.
    """

    def __init__(self, units, **kwargs):
        # Initialise the Keras base class before touching instance attributes.
        super(myGRU, self).__init__(**kwargs)
        self.units = units

    @property
    def state_size(self):
        """Size of the single recurrent state carried between steps."""
        return self.units

    @property
    def output_size(self):
        """Size of the per-step output (identical to the hidden state)."""
        return self.units

    def build(self, input_shape):
        """Create the kernel and bias of the reset, update and candidate gates.

        Args:
            input_shape: Shape of one step's input; only the last entry
                (the feature dimension) is used.
        """
        self.dim = input_shape[-1]
        kernel_shape = [self.dim + self.units, self.units]
        bias_shape = [self.units]

        # NOTE(review): weight names are kept byte-identical (including the
        # spaces and the 'intetim' spelling) so that name-based checkpoints
        # stay loadable -- confirm before renaming them.
        self.w_r = self.add_weight(shape=kernel_shape, initializer='uniform', name='reset_gate', trainable=True)
        self.b_r = self.add_weight(shape=bias_shape, initializer='zeros', name='reset gate bias', trainable=True)

        self.w_z = self.add_weight(shape=kernel_shape, initializer='uniform', name='update_gate', trainable=True)
        self.b_z = self.add_weight(shape=bias_shape, initializer='zeros', name='update gate bias', trainable=True)

        self.w_n = self.add_weight(shape=kernel_shape, initializer='uniform', name='intetim', trainable=True)
        self.b_n = self.add_weight(shape=bias_shape, initializer='zeros', name='intetim gate bias', trainable=True)

        self.built = True

    def call(self, inputs, states):
        """Run one GRU step.

        Args:
            inputs: Input for the current step; its last dimension must equal
                the feature size seen in ``build``.
            states: Sequence whose first element is the previous hidden state.

        Returns:
            A pair ``(output, new_state)``; both are the new hidden state.
        """
        prev_output = states[0]

        # Split every kernel into its input part and its recurrent part.
        w_r_x, w_r_h = self.w_r[:self.dim], self.w_r[self.dim:]
        w_z_x, w_z_h = self.w_z[:self.dim], self.w_z[self.dim:]
        w_n_x, w_n_h = self.w_n[:self.dim], self.w_n[self.dim:]

        r = tf.nn.sigmoid(inputs @ w_r_x + prev_output @ w_r_h + self.b_r)
        z = tf.nn.sigmoid(inputs @ w_z_x + prev_output @ w_z_h + self.b_z)

        # The candidate state must use its own kernel (w_n), not the reset
        # gate's. The reset gate scales the recurrent contribution per sample,
        # after the matmul, so the expression is valid for any batch size
        # instead of relying on broadcasting against the kernel.
        n = tf.nn.tanh(inputs @ w_n_x + r * (prev_output @ w_n_h) + self.b_n)

        output = (1 - z) * prev_output + z * n

        return output, output

    def get_config(self):
        """Return the constructor arguments so the cell can be re-created."""
        config = super(myGRU, self).get_config()
        config.update({'units': self.units})
        return config

class GRUReg(keras.models.Model):
    """Sequence regressor: a ``myGRU``-based RNN layer followed by ``Dense(1)``.

    Args:
        units: Hidden size passed to the ``myGRU`` cell.
        **kwargs: Forwarded unchanged to ``keras.models.Model``.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        # Wrap the custom cell in the generic RNN layer (default arguments).
        recurrent_cell = myGRU(units)
        self.gru = keras.layers.RNN(recurrent_cell)
        # Single-unit head producing the regression value.
        self.dense = keras.layers.Dense(1)

    def call(self, inputs):
        """Feed ``inputs`` through the recurrent layer, then the dense head."""
        encoded = self.gru(inputs)
        return self.dense(encoded)
# Smoke test: build the model, compile it with RMSprop and an MAE loss, and
# fit it for a few epochs on constant dummy tensors, one sample per batch.
model = GRUReg(units=32)
model.compile(
    loss='mae',
    metrics=['mae'],
    optimizer=keras.optimizers.RMSprop(),
)
model.fit(
    x=tf.ones((3, 128, 2)),
    y=tf.zeros(3),
    batch_size=1,
    epochs=5,
)