The story: When working with large data, using model.predict together with a result off-loading callback (to write the results to disk) gives much better performance than looping over batches and calling model.predict_on_batch() or model(input, training=False) with the same off-loading step. However, model.predict accumulates all of the results and eventually causes an out-of-memory error.
The request: I was wondering if you could add an argument to model.predict that prevents it from accumulating the predictions, allowing one to use a callback to handle the results instead.
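For context, the off-loading callback I have in mind looks roughly like the sketch below (the class name, save_dir, and file naming are just placeholders); on_predict_batch_end already receives the per-batch outputs in logs, so the only missing piece is a way to tell model.predict not to keep its own copy of every batch:

```python
import os

import numpy as np
import tensorflow as tf


class OffloadPredictions(tf.keras.callbacks.Callback):
    """Write each batch of predictions to disk instead of keeping it in RAM."""

    def __init__(self, save_dir):
        super().__init__()
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)

    def on_predict_batch_end(self, batch, logs=None):
        # logs["outputs"] holds this batch's predictions.
        np.save(os.path.join(self.save_dir, f"batch_{batch:06d}.npy"), logs["outputs"])


# The callback streams results to disk, but predict() still accumulates
# every batch internally, which is what eventually runs out of memory:
# model.predict(dataset, callbacks=[OffloadPredictions("predictions")])
```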
Comment From: sonali-kumari1
Hi @ErfanMowlaei - Thanks for reporting this issue! Could you please share minimal reproducible code to replicate the performance issue and the out-of-memory error?
Comment From: ErfanMowlaei
Hi @sonali-kumari1 -
Thank you for your attention. You could use the following code:
```python
import tensorflow as tf
import numpy as np
import time


def build_model(seq_len=1_000_000, channels=3, filters=128):
    return tf.keras.Sequential([
        tf.keras.layers.Conv1D(filters, kernel_size=7, padding='same', activation='relu', input_shape=(seq_len, channels)),
        tf.keras.layers.Conv1D(filters, kernel_size=7, padding='same', activation='relu'),
        tf.keras.layers.Conv1D(filters, kernel_size=7, padding='same', activation='relu'),
    ])


def generate_sequences(n, seq_len=1_000_000, channels=3, seed=42):
    rng = np.random.default_rng(seed)
    return rng.standard_normal(size=(n, seq_len, channels)).astype(np.float32)


if __name__ == "__main__":
    # Parameters
    n_samples = 50
    seq_len = 1_000_000
    channels = 3

    # Model and data
    model = build_model(seq_len, channels)
    data = generate_sequences(n_samples, seq_len, channels)

    # tf.data.Dataset pipeline
    dataset = tf.data.Dataset.from_tensor_slices(data).batch(1)

    # warm up
    model.predict_on_batch(next(iter(dataset)))

    # Benchmark model.predict_on_batch
    start = time.time()
    for batch in dataset:
        model.predict_on_batch(batch)
    end = time.time()
    print(f"model.predict_on_batch (manual batch): {end - start:.2f} seconds", flush=True)

    # Benchmark model.predict
    model.predict(dataset)
```
Here's the output:
```
model.predict_on_batch (manual batch): 7.00 seconds
50/50 ━━━━━━━━━━━━━━━━━━━━ 4s 84ms/step
```
If you increase the number of samples (e.g., to 200), model.predict(...) causes an OOM error (depending on how much RAM you have).
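A back-of-the-envelope estimate of why this blows up (assuming the accumulated outputs are kept as plain float32 arrays):

```python
# Size of the predictions that model.predict accumulates in RAM,
# assuming float32 outputs of shape (n_samples, seq_len, filters).
n_samples, seq_len, filters = 200, 1_000_000, 128
total_bytes = n_samples * seq_len * filters * 4  # 4 bytes per float32
print(f"{total_bytes / 1e9:.0f} GB")  # ~102 GB for the predictions alone
```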
Comment From: sonali-kumari1
@ErfanMowlaei -
I have tested your code with the latest version of Keras (3.11.3) on a T4 GPU and encountered a session crash, likely due to memory exhaustion caused by model.predict(). We will look into this and update you soon. Thanks!
Comment From: ErfanMowlaei
@sonali-kumari1 Thank you and the cool Keras dev team!
Comment From: pctablet505
@ErfanMowlaei This fixes the issue; you can temporarily try this branch: https://github.com/pctablet505/keras/tree/non-cumulative-model.predict
From my testing, I see no performance gain when using model.predict instead of predict_on_batch: https://colab.research.google.com/gist/pctablet505/cbf7c3f97d3e404b5786da6b57fa4e42/non-cumulative-model-predict.ipynb
Comment From: ErfanMowlaei
@pctablet505 Thanks. There was an issue with your branch that I left a comment on; please check it out. Also, did you verify that, with accumulate set to False, we can still retrieve the predictions in the callbacks?
About the gist, I think you may have misread the results there. The manual predict_on_batch loop in your gist is printed as taking 142s, while model.predict takes 66s.
Comment From: pctablet505
I've corrected the code and verified the callbacks: https://colab.research.google.com/gist/pctablet505/987ff6b69c7adce8887fd9e5919c59d5/non-cumulative-model-predict.ipynb
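A rough sketch of how the callback path can be checked (the argument name accumulate is taken from the discussion above and may change before merging):

```python
import tensorflow as tf


class CheckOutputs(tf.keras.callbacks.Callback):
    """Sanity check that callbacks still receive per-batch outputs."""

    def on_predict_batch_end(self, batch, logs=None):
        assert logs is not None and "outputs" in logs
        print(f"batch {batch}: outputs shape {logs['outputs'].shape}")


# On the linked branch:
# model.predict(dataset, accumulate=False, callbacks=[CheckOutputs()])
```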
Comment From: ErfanMowlaei
@pctablet505 Great, thank you very much! I hope to see it in the next patch!