Обновление глобальной переменной во время обучения нейронной сети tensorflow

Рейтинг: 0Ответов: 0Опубликовано: 07.03.2023

В настоящий момент, изучаю вариационный автоэнкодер и пытаюсь проверить метод обучения с отжигом KL ошибки. Для данной операции необходимо равномерно увеличивать нормирующий коэффициент бета (β * KL). И я столкнулся с проблемой, у меня имеется код функции Callback

class CustomCallback(keras.callbacks.Callback):
   
   def __init__(self, validation_data = ()):
       super(keras.callbacks.Callback, self).__init__()
       self.X_val, self.y_val = validation_data
       self.log_loss = []
       self.log_val_loss = []
       self.log_reconstruction_loss = []
       self.KL_loss = []
       
   def on_epoch_end(self, epoch, logs={}):
       global beta
#     Глобальная переменная для вычисления коэффциента beta
       beta = tf.cond(beta == 10, lambda: tf.math.reduce_min([beta,10]), lambda: tf.math.reduce_sum([beta,1]))
       tf.print(beta)
#         beta = tf.math.reduce_sum([beta,1]).numpy()

Данная функция вызывается в конце обучения. Для теста я увеличиваю значение beta таким способом

beta = tf.cond(beta == 10, lambda: tf.math.reduce_min([beta,10]), lambda: tf.math.reduce_sum([beta,1]))

Таким образом я могу контролировать, что значение увеличивается после окончания обучения.

Данная переменная используется в функции def train_step(self, data): для нормировки значения KL

#   Обучение нейронной сети
    @tf.function
    def train_step(self, data):
        global beta
        data_len = tf.shape(data)
        data_len = tf.cast(data_len[0], tf.float32)
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
#----------------------------------------------------------------------------------
#----------------------------------------------------------------------------------
#----------------------------------------------------------------------------------
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
#----------------------------------------------------------------------------------
#----------------------------------------------------------------------------------
#----------------------------------------------------------------------------------
            total_loss = reconstruction_loss + beta * kl_loss
            tf.print(beta)

Но возникает проблема, значение beta в функции def train_step(self, data): не увеличивается. Как можно реализовать обновление значения при обучении.

Также, значение beta увеличивается, но не изменяется в функции def train_step(self, data):

Ответы

Ответов пока нет.