Сохранение отслеживаемых объектов в TensorFlow через Checkpoint

Рейтинг: 1Ответов: 1Опубликовано: 24.07.2023

Видимо, не совсем понимаю концепцию, как должен работать train.Checkpoint() в TF.

В документации сказано, что Checkpoint служит для сохранения отслеживаемых объектов. Итак, tf.Variables у нас именно такой.

var = tf.Variable([False] * 8, name="var")

saver = tf.train.Checkpoint()
save_path = saver.save("checkpoint")

var[0].assign(True)

saver.restore(save_path)
print(var)

Т.е. у меня есть отслеживаемая переменная var. Это вектор из булевых значений. Я создаю чекпоинт, сохраняю состояние. Далее изменяю переменную, загружаю сохраненное состояние, но переменная остается прежней, с изменениями. Вероятно, очень глупый вопрос, но пытаюсь понять, как это должно работать.

Ответы

▲ 2Принят

Отвечу на свой вопрос спустя 2 месяца так, какой ответ хотелось бы получить тогда. Может кому будет полезно.

При создании чекпоинта мы должны указать, какие переменные мы сохраняем. Т.е. тогда получится:

saver = tf.train.Checkpoint(var=var)
var[0].assign(True)
saver.restore(save_path).assert_consumed()

На практике мы само собой чаще всего сохраняем параметры модели. Это может выглядеть следующим образом:

ckpt = tf.train.Checkpoint(step=tf.Variable(1), W1=model.W1, b1=model.b1, b_out=model.b_out)

Далее мы можем использовать CheckpointManager, где указываем путь до сохранения и максимальное количество сохраненных чекпоинтов.

manager = tf.train.CheckpointManager(ckpt, 'PATH', max_to_keep=3)

step, который мы задали в Checkpoint, можно удобно итерировать по ходу работы алгоритма, и в нужный момент сохранять данные в файл. Например, каждые 100 итераций:

ckpt.step.assign_add(1)
if int(ckpt.step) % 100 == 0:
    save_path = manager.save()

Чтобы восстановить состояние переменных из чекпоинта, используем restore:

ckpt.restore(PATH).assert_consumed()