The story: With large data, using model.predict plus a result off-loading callback (to write the results to disk) gives much better performance than looping over batches and calling model.predict_on_batch() or model(input, training=False). However, model.predict accumulates all of the results, which causes an out-of-memory error.
The request: I was wondering if you could add an argument to model.predict that prevents it from accumulating the model predictions, allowing one to use a callback to handle the results instead. A sketch of what this could look like follows.
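For illustration only, here is a minimal sketch of the envisioned usage. The `accumulate_outputs` flag is hypothetical (it does not exist in Keras today), while `on_predict_batch_end` is part of the existing Keras callback API; whether `logs["outputs"]` is populated can vary across Keras/TensorFlow versions:

```python
import numpy as np
import tensorflow as tf

class OffloadPredictions(tf.keras.callbacks.Callback):
    """Writes each batch of predictions to disk instead of keeping them in memory."""

    def on_predict_batch_end(self, batch, logs=None):
        # In tf.keras, predict callbacks typically receive the batch
        # outputs under logs["outputs"] (version-dependent).
        outputs = logs["outputs"]
        np.save(f"preds_batch_{batch:05d}.npy", np.asarray(outputs))

# Hypothetical API: accumulate_outputs=False would tell predict() to skip
# collecting and concatenating per-batch results, keeping memory bounded.
# model.predict(dataset, callbacks=[OffloadPredictions()], accumulate_outputs=False)
```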
Comment From: sonali-kumari1
Hi @ErfanMowlaei - Thanks for reporting this issue! Could you please share minimal reproducible code to replicate the performance issue and the out-of-memory error?
Comment From: ErfanMowlaei
Hi @sonali-kumari1 -
Thank you for your attention. You could use the following code:
```python
import tensorflow as tf
import numpy as np
import time


def build_model(seq_len=1_000_000, channels=3, filters=128):
    return tf.keras.Sequential([
        tf.keras.layers.Conv1D(filters, kernel_size=7, padding='same', activation='relu', input_shape=(seq_len, channels)),
        tf.keras.layers.Conv1D(filters, kernel_size=7, padding='same', activation='relu'),
        tf.keras.layers.Conv1D(filters, kernel_size=7, padding='same', activation='relu'),
    ])


def generate_sequences(n, seq_len=1_000_000, channels=3, seed=42):
    rng = np.random.default_rng(seed)
    return rng.standard_normal(size=(n, seq_len, channels)).astype(np.float32)


if __name__ == "__main__":
    # Parameters
    n_samples = 50
    seq_len = 1_000_000
    channels = 3

    # Model and data
    model = build_model(seq_len, channels)
    data = generate_sequences(n_samples, seq_len, channels)

    # tf.data.Dataset pipeline
    dataset = tf.data.Dataset.from_tensor_slices(data).batch(1)

    # Warm-up
    model.predict_on_batch(next(iter(dataset)))

    # Benchmark model.predict_on_batch
    start = time.time()
    for batch in dataset:
        model.predict_on_batch(batch)
    end = time.time()
    print(f"model.predict_on_batch (manual batch): {end - start:.2f} seconds", flush=True)

    # Benchmark model.predict
    model.predict(dataset)
```
Here's the output:
```
model.predict_on_batch (manual batch): 7.00 seconds
50/50 ━━━━━━━━━━━━━━━━━━━━ 4s 84ms/step
```
If you increase the number of samples (e.g., to 200), model.predict(...) causes an OOM error even on an NVIDIA A100 80GB GPU because of prediction accumulation: the full output would be 200 × 1,000,000 × 128 float32 values, roughly 100 GB.
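For reference, a minimal sketch of the memory-bounded workaround implied by the benchmark above (looping with predict_on_batch and writing each result to disk); the file naming is illustrative:

```python
# Memory-bounded loop: each batch's predictions are written to disk and
# released, but per the timings above this path is notably slower than
# a single model.predict() call over the dataset.
for i, batch in enumerate(dataset):
    preds = model.predict_on_batch(batch)
    np.save(f"preds_{i:05d}.npy", preds)
```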