This repository has been archived by the owner on Jul 27, 2023. It is now read-only.

How to save models? #101

Open
jminsol opened this issue Nov 12, 2020 · 2 comments

Comments

@jminsol

jminsol commented Nov 12, 2020

Hello, I am a student who has just started learning machine learning. Your code was really helpful. I'd like to save the models and restore them later to use with different data sets. How can I save the models and successfully restore them later?

Here is my code that I've worked on so far.

```python
    # ... end of the forecasting function, after `modelnn` has been trained in `sess` ...
    output_predict = minmax.inverse_transform(output_predict)
    deep_future = self.anchor(output_predict[:, 0], 0.3)

    # Save the trained variables. Creating new tf.Variables (e.g. from
    # tf.random_normal) and a fresh tf.Session here would checkpoint random
    # values rather than the trained model, so reuse the training session.
    # Assumes `sess` is the session that ran the training loop.
    saver = tf.train.Saver()
    saver.save(sess, './my_model', global_step=1000)

    return deep_future[-test_size:]
```
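With the `tf.compat.v1` graph API, a checkpoint only contains useful values if the `Saver` runs on the session that trained the model, and restoring requires rebuilding variables with the same names and shapes. A minimal sketch of that round trip, assuming TensorFlow 2.x; the single variable named `weights` and the `assign` that stands in for training are hypothetical placeholders for a real model.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def train_and_save(ckpt_dir: str) -> None:
    graph = tf1.Graph()
    with graph.as_default():
        # Hypothetical stand-in for the real model: one variable named "weights".
        weights = tf1.get_variable(
            "weights", shape=[3], initializer=tf1.zeros_initializer())
        # Hypothetical stand-in for training: overwrite with "learned" values.
        train_op = weights.assign([1.0, 2.0, 3.0])
        # Create the Saver after all variables exist so it tracks them.
        saver = tf1.train.Saver()
        with tf1.Session() as sess:
            sess.run(tf1.global_variables_initializer())
            sess.run(train_op)
            # Save from the same session that did the training.
            saver.save(sess, os.path.join(ckpt_dir, "my_model"), global_step=1000)


def restore(ckpt_dir: str) -> np.ndarray:
    graph = tf1.Graph()
    with graph.as_default():
        # Rebuild a variable with the same name and shape. No initializer is
        # run, because restore() assigns the saved values.
        weights = tf1.get_variable("weights", shape=[3])
        saver = tf1.train.Saver()
        with tf1.Session() as sess:
            # latest_checkpoint reads the "checkpoint" index file written by save().
            saver.restore(sess, tf1.train.latest_checkpoint(ckpt_dir))
            return sess.run(weights)


ckpt_dir = tempfile.mkdtemp()
train_and_save(ckpt_dir)
restored = restore(ckpt_dir)
print(restored)  # [1. 2. 3.]
```

The restore side never calls the training code, so the same pattern works when loading the model later on a different data set, as long as the graph is rebuilt identically.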
@HariniNarasimhan

If you have found the best way to save the model, please let me know!

@nil-andreu

nil-andreu commented Feb 18, 2023

@jminsol @HariniNarasimhan You can either save the whole model or save only the weights. If you only want to save the weights (which takes less storage space):

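```python
import math
import os

import tensorflow as tf

# Include the epoch in the file name (uses `str.format`).
# `.ckpt` is the TF checkpoint format (Keras 2); Keras 3 requires a `.weights.h5` suffix.
checkpoint_path = "training/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# An integer `save_freq` counts batches, not epochs, so saving "every 5 epochs"
# means 5 * (batches per epoch). `train_size` and `batch_size` stand for your
# own number of training samples and batch size.
n_batches = math.ceil(train_size / batch_size)

# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_freq=5 * n_batches)

# Create a new model instance (`create_model` is your own model-building function)
model = create_model()

# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))

# Train the model; the callback only runs if it is passed to fit()
model.fit(..., callbacks=[cp_callback])

# Later, rebuild the same architecture and load the saved weights,
# e.g. from tf.train.latest_checkpoint(checkpoint_dir)
model = create_model()
model.load_weights(...)
```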
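Because an integer `save_freq` in `tf.keras.callbacks.ModelCheckpoint` is measured in batches, the right value for "every N epochs" depends on the dataset size. A small worked example, assuming each epoch makes one full pass over the training data (no `steps_per_epoch` override); `save_freq_for_epochs` is a hypothetical helper, not part of Keras.

```python
import math


def save_freq_for_epochs(num_samples: int, batch_size: int, every_n_epochs: int) -> int:
    """Integer `save_freq` (in batches) that saves once every `every_n_epochs` epochs."""
    # A final partial batch still counts as one step, hence the ceiling.
    batches_per_epoch = math.ceil(num_samples / batch_size)
    return every_n_epochs * batches_per_epoch


# 60,000 samples with batch_size=32 -> 1875 batches per epoch -> 9375
print(save_freq_for_epochs(60_000, 32, 5))
# 1,000 samples with batch_size=64 -> 15.625, rounded up to 16 batches -> 80
print(save_freq_for_epochs(1_000, 64, 5))
```

Using `5 * batch_size` instead would give 160 in the first case, i.e. a save every 160 batches, which has nothing to do with epoch boundaries.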

For more details, see the TensorFlow documentation on saving and loading models.
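Saving the whole model, the other option mentioned above, can be sketched as follows, with the weights-only route alongside for comparison. This assumes TensorFlow 2.13 or later (native `.keras` model files and `.weights.h5` weight files); the tiny `build_model` is a hypothetical stand-in for your own `create_model()`.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model() -> tf.keras.Model:
    # Hypothetical tiny model standing in for your own create_model().
    model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")
    return model


tmp = tempfile.mkdtemp()
model = build_model()

# Option 1: whole model (architecture + weights + compile/optimizer config).
full_path = os.path.join(tmp, "model.keras")
model.save(full_path)
reloaded = tf.keras.models.load_model(full_path)  # no build_model() needed

# Option 2: weights only; the architecture must be rebuilt before loading.
weights_path = os.path.join(tmp, "model.weights.h5")
model.save_weights(weights_path)
rebuilt = build_model()
rebuilt.load_weights(weights_path)
```

The whole-model file can be reloaded without the model-building code, whereas a weights file can only be loaded into a model with the same architecture.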


3 participants