TensorFlow checkpoints when training a model

Hello everyone,

This is the training part of the code for a CycleGAN model that I have implemented.

Training cycleGAN:
import time

import tensorflow as tf

# cycleGAN architecture

def cyclegan(input_A, input_B):
    # fake images generation
    BfromA = generateB(input_A, training = True)
    AfromB = generateA(input_B, training = True)
    # images reconstruction
    regenAfromB = generateA(BfromA, training = True)
    regenBfromA = generateB(AfromB, training = True)

    # auto-generating
    gen_orig_A = generateA(input_A, training = True)
    gen_orig_B = generateB(input_B, training = True)
    # auto-validating
    valid_A = discriminateA(input_A, training = True)
    valid_B = discriminateB(input_B, training = True)
    # fake images validating
    valid_AfromB = discriminateA(AfromB, training = True)
    valid_BfromA = discriminateB(BfromA, training = True)
    return regenAfromB, regenBfromA, gen_orig_A, gen_orig_B, valid_A, valid_B, valid_AfromB, valid_BfromA

# Loss Functions - Optimizers

def generator_loss(generated):
    return tf.keras.losses.BinaryCrossentropy(from_logits = True, reduction=tf.keras.losses.Reduction.NONE)(tf.ones_like(generated), generated)

def discriminator_loss(real, generated):
    real_loss = tf.keras.losses.BinaryCrossentropy(from_logits = True, reduction=tf.keras.losses.Reduction.NONE)(tf.ones_like(real), real)
    generated_loss = tf.keras.losses.BinaryCrossentropy(from_logits = True,
                                                        reduction=tf.keras.losses.Reduction.NONE)(tf.zeros_like(generated), generated)
    total_disc_loss = real_loss + generated_loss
    return total_disc_loss

def cycle_loss(real, generated, LAMBDA):
    c_loss = tf.reduce_mean(tf.abs(real - generated))
    return LAMBDA * c_loss

def identity_loss(real, same, LAMBDA):
    i_loss = tf.reduce_mean(tf.abs(real - same))

    return LAMBDA * i_loss

gen_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1 = 0.5)
disc_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1 = 0.5)

# Training session

generateA = generator()
discriminateA = discriminator()
generateB = generator()
discriminateB = discriminator()

inputA = tf.keras.layers.Input(shape = [HEIGHT, WIDTH, CHANNEL])
inputB = tf.keras.layers.Input(shape = [HEIGHT, WIDTH, CHANNEL])

def train_step(inputA, inputB):
    with tf.GradientTape(persistent = True) as tape:
        regenA, regenB, gen_origA, gen_origB, disc_A, disc_B, disc_AfB, disc_BfA = cyclegan(inputA, inputB)
        A_gen_loss = generator_loss(disc_AfB)
        B_gen_loss = generator_loss(disc_BfA)
        total_cycle_loss = cycle_loss(inputA, regenA, LAMBDA) + cycle_loss(inputB, regenB, LAMBDA)
        A_identity_loss = identity_loss(inputA, gen_origA, LAMBDA)
        B_identity_loss = identity_loss(inputB, gen_origB, LAMBDA)
        total_A_gen_loss = A_gen_loss + total_cycle_loss + A_identity_loss
        total_B_gen_loss = B_gen_loss + total_cycle_loss + B_identity_loss
        A_disc_loss = discriminator_loss(disc_A, disc_AfB)
        B_disc_loss = discriminator_loss(disc_B, disc_BfA)

    # Gradients and optimizers
    A_generator_gradients = tape.gradient(total_A_gen_loss, generateA.trainable_variables)
    gen_optimizer.apply_gradients(zip(A_generator_gradients, generateA.trainable_variables))

    B_generator_gradients = tape.gradient(total_B_gen_loss, generateB.trainable_variables)
    gen_optimizer.apply_gradients(zip(B_generator_gradients, generateB.trainable_variables))
    A_discriminator_gradients = tape.gradient(A_disc_loss, discriminateA.trainable_variables)
    disc_optimizer.apply_gradients(zip(A_discriminator_gradients, discriminateA.trainable_variables))

    B_discriminator_gradients = tape.gradient(B_disc_loss, discriminateB.trainable_variables)
    disc_optimizer.apply_gradients(zip(B_discriminator_gradients, discriminateB.trainable_variables))

# Training
def train(train_ds, epochs):
    for epoch in range(epochs):
        start = time.time()
        print("Starting epoch", epoch + 1)

        for image_x, image_y in train_ds:
            train_step(image_x.numpy(), image_y.numpy())
        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))
        save_step(input_path_A, sample_img, epoch, 'P', generateB, step_path)

I need to use TensorFlow checkpoints so that I can train the model over multiple runs, but I have no idea how to incorporate them. I haven't used functions like fit(), model(), compile()...

Would anyone be able to help me?

Hi there,

TensorFlow checkpoints save the state of your training, such as the model weights and the optimizer state, to disk so that a later run can restore it and resume training from that point instead of starting over.

To incorporate checkpoints into your code, you can use the tf.train.Checkpoint class, which works with a custom training loop and does not require fit() or compile(). First, decide which objects to save. In your case that means the two generators and the two discriminators, and also both optimizers, because Adam keeps per-variable state that is needed to resume training smoothly. Create a Checkpoint instance and pass each object in as a keyword argument. For example:

checkpoint = tf.train.Checkpoint(generator=generateA)

The remaining models and the optimizers can be passed to the same Checkpoint as further keyword arguments, so that everything is saved and restored together.
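
As a fuller sketch of the step above, a single Checkpoint can track all four models and both optimizers from your code. The tiny_model() helper is only a hypothetical stand-in for your generator() and discriminator() functions, and the epoch counter is an addition that is not in your code; it is there so a later run can tell how many epochs have already been completed.

```python
import tensorflow as tf


def tiny_model():
    # Hypothetical stand-in for generator() / discriminator() from the question.
    return tf.keras.Sequential([tf.keras.Input(shape=(4,)),
                                tf.keras.layers.Dense(4)])


generateA, generateB = tiny_model(), tiny_model()
discriminateA, discriminateB = tiny_model(), tiny_model()
gen_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
disc_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# Assumption: a counter of finished epochs, saved alongside the models.
epoch_counter = tf.Variable(0, dtype=tf.int64)

# Every keyword argument becomes a named entry in the saved checkpoint.
checkpoint = tf.train.Checkpoint(
    generateA=generateA,
    generateB=generateB,
    discriminateA=discriminateA,
    discriminateB=discriminateB,
    gen_optimizer=gen_optimizer,
    disc_optimizer=disc_optimizer,
    epoch=epoch_counter,
)
```

The keyword names are free to choose, but they should stay the same between the run that saves and the run that restores.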

Next, you need to decide when during training to save. This is usually done at the end of each epoch, but you can choose other intervals if needed. To save, call the save() method on your checkpoint object and pass it a file prefix, such as a directory followed by "ckpt". TensorFlow appends a running number to the prefix on each call, so earlier checkpoints are not overwritten.
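
One possible way to save at the end of each epoch is sketched below. It uses tf.train.CheckpointManager, which wraps a Checkpoint, numbers the files, and deletes old ones. The temporary directory, the single stand-in model, and the epoch counter are assumptions for illustration; in your code the Checkpoint would hold your own models and optimizers.

```python
import tempfile

import tensorflow as tf

# Hypothetical stand-in for one of the models from the question.
generator = tf.keras.Sequential([tf.keras.Input(shape=(4,)),
                                 tf.keras.layers.Dense(4)])
epoch_counter = tf.Variable(0, dtype=tf.int64)  # assumption: counts finished epochs

checkpoint = tf.train.Checkpoint(generator=generator, epoch=epoch_counter)

# Assumption: a throwaway directory; use a persistent path in real training.
checkpoint_dir = tempfile.mkdtemp()
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)

saved_paths = []
for epoch in range(5):
    # ... the train_step(image_x, image_y) calls for this epoch would go here ...
    epoch_counter.assign_add(1)
    # Writes a new numbered checkpoint and removes the oldest beyond max_to_keep.
    saved_paths.append(manager.save())
```

With max_to_keep=3 only the three most recent checkpoints stay on disk, which helps when the models are large.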


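To restore a checkpoint, call the restore() method on your checkpoint object, passing in the path of a saved checkpoint. tf.train.latest_checkpoint() returns the path of the most recent checkpoint in a directory, or None if there is none yet.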
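
A minimal sketch of a save followed by a restore into freshly created objects, as would happen in a second run. The build_objects() helper, the single stand-in model, and the epoch counter are assumptions for illustration, not part of your code.

```python
import os
import tempfile

import tensorflow as tf


def build_objects():
    # Hypothetical stand-in: each "run" creates its objects from scratch.
    generator = tf.keras.Sequential([tf.keras.Input(shape=(4,)),
                                     tf.keras.layers.Dense(4)])
    epoch = tf.Variable(0, dtype=tf.int64)
    ckpt = tf.train.Checkpoint(generator=generator, epoch=epoch)
    return generator, epoch, ckpt


checkpoint_dir = tempfile.mkdtemp()  # assumption: throwaway directory

# "First run": pretend 7 epochs were trained, then save.
generator_1, epoch_1, ckpt_1 = build_objects()
epoch_1.assign(7)
ckpt_1.save(file_prefix=os.path.join(checkpoint_dir, "ckpt"))

# "Second run": new objects with different random weights and epoch 0.
generator_2, epoch_2, ckpt_2 = build_objects()
latest = tf.train.latest_checkpoint(checkpoint_dir)
ckpt_2.restore(latest)  # overwrites the fresh values with the saved ones
```

After the restore() call, the second set of objects holds the weights and epoch count that the first set had when it was saved.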


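Call restore() once at the start of each run, after the models and optimizers have been created and before the training loop begins. The new run must pass its objects to the Checkpoint under the same keyword names that were used when saving, because those names are how saved values are matched to objects.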
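
Putting the steps together, a training function that can be stopped and restarted might look like the sketch below. It is not a drop-in replacement for your train(): the extra parameters, the new_run() helper, the single weight variable standing in for your models, and the dummy dataset are all assumptions chosen to keep the example small.

```python
import tempfile

import tensorflow as tf


def train(train_ds, epochs, checkpoint, manager, epoch_counter, train_step):
    # On the very first run latest_checkpoint is None and nothing is loaded.
    checkpoint.restore(manager.latest_checkpoint)
    epochs_run = 0
    # Continue from the saved epoch count instead of starting at 0.
    for epoch in range(int(epoch_counter.numpy()), epochs):
        for image_x, image_y in train_ds:
            train_step(image_x, image_y)
        epoch_counter.assign_add(1)
        manager.save()
        epochs_run += 1
    return epochs_run


def new_run(checkpoint_dir):
    # Hypothetical stand-ins: one variable plays the role of all model weights.
    weight = tf.Variable(0.0)
    epoch_counter = tf.Variable(0, dtype=tf.int64)
    checkpoint = tf.train.Checkpoint(weight=weight, epoch=epoch_counter)
    manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=2)

    def train_step(image_x, image_y):
        weight.assign_add(1.0)  # placeholder for the real gradient update

    return weight, epoch_counter, checkpoint, manager, train_step


checkpoint_dir = tempfile.mkdtemp()  # assumption: throwaway directory
train_ds = [(tf.zeros([1, 4]), tf.zeros([1, 4]))] * 2  # two dummy batches

# First run: train up to epoch 2, then stop.
weight, epoch_counter, checkpoint, manager, train_step = new_run(checkpoint_dir)
first_run_epochs = train(train_ds, 2, checkpoint, manager, epoch_counter, train_step)

# Second run: fresh objects, same directory, train up to epoch 5.
weight, epoch_counter, checkpoint, manager, train_step = new_run(checkpoint_dir)
second_run_epochs = train(train_ds, 5, checkpoint, manager, epoch_counter, train_step)
```

The second call only runs the epochs that are still missing, because the epoch counter is stored in the checkpoint together with the weights.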

I hope this helps! Let me know if you have any other questions.

