Обновление глобальной переменной во время обучения нейронной сети tensorflow
В настоящий момент, изучаю вариационный автоэнкодер и пытаюсь проверить метод обучения с отжигом KL ошибки. Для данной операции необходимо равномерно увеличивать нормирующий коэффициент бета (β * KL). И я столкнулся с проблемой, у меня имеется код функции Callback
class CustomCallback(keras.callbacks.Callback):
def __init__(self, validation_data = ()):
super(keras.callbacks.Callback, self).__init__()
self.X_val, self.y_val = validation_data
self.log_loss = []
self.log_val_loss = []
self.log_reconstruction_loss = []
self.KL_loss = []
def on_epoch_end(self, epoch, logs={}):
global beta
# Глобальная переменная для вычисления коэффциента beta
beta = tf.cond(beta == 10, lambda: tf.math.reduce_min([beta,10]), lambda: tf.math.reduce_sum([beta,1]))
tf.print(beta)
# beta = tf.math.reduce_sum([beta,1]).numpy()
Данная функция вызывается в конце обучения. Для теста я увеличиваю значение beta таким способом
beta = tf.cond(beta == 10, lambda: tf.math.reduce_min([beta,10]), lambda: tf.math.reduce_sum([beta,1]))
Таким образом я могу контролировать, что значение увеличивается после окончания обучения.
Данная переменная используется в функции def train_step(self, data):
для нормировки значения KL
# Обучение нейронной сети
@tf.function
def train_step(self, data):
global beta
data_len = tf.shape(data)
data_len = tf.cast(data_len[0], tf.float32)
with tf.GradientTape() as tape:
z_mean, z_log_var, z = self.encoder(data)
#----------------------------------------------------------------------------------
#----------------------------------------------------------------------------------
#----------------------------------------------------------------------------------
reconstruction = self.decoder(z)
reconstruction_loss = tf.reduce_mean(
tf.reduce_sum(
keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
)
)
kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
#----------------------------------------------------------------------------------
#----------------------------------------------------------------------------------
#----------------------------------------------------------------------------------
total_loss = reconstruction_loss + beta * kl_loss
tf.print(beta)
Но возникает проблема, значение beta в функции def train_step(self, data):
не увеличивается.
Как можно реализовать обновление значения при обучении.
Также, значение beta увеличивается, но не изменяется в функции def train_step(self, data):