The story: I am working with large data, and using model.predict together with a result off-loading callback (to write the predictions to disk) gives much better performance than looping over batches and calling model.predict_on_batch() or model(input, training=False) with the same off-loading step. However, model.predict accumulates all of the predictions in memory, which causes an out-of-memory error.

The request: I was wondering if you could add an argument to model.predict that prevents it from accumulating the predictions, allowing one to handle the results with a callback instead (a rough sketch of such a callback is given below).
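
For context, here is only a sketch of the kind of off-loading callback I mean. It assumes the per-batch predictions are exposed under the "outputs" key of logs in on_predict_batch_end (as in current tf.keras); the class name, output directory, and .npy file format are placeholders:

import numpy as np
import tensorflow as tf

class OffloadPredictions(tf.keras.callbacks.Callback):
    """Writes each batch of predictions straight to disk instead of keeping it in RAM."""

    def __init__(self, out_dir):
        super().__init__()
        self.out_dir = out_dir

    def on_predict_batch_end(self, batch, logs=None):
        # In current tf.keras the batch predictions arrive in logs["outputs"].
        outputs = (logs or {}).get("outputs")
        if outputs is not None:
            np.save(f"{self.out_dir}/batch_{batch:06d}.npy", np.asarray(outputs))

# Usage: model.predict(dataset, callbacks=[OffloadPredictions("preds")])
# Even with such a callback, predict() still accumulates every batch internally,
# which is the behaviour this request asks to make optional.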

Comment From: sonali-kumari1

Hi @ErfanMowlaei - Thanks for reporting this issue! Could you please share a minimal reproducible example so we can replicate the performance issue and the out-of-memory error?

Comment From: ErfanMowlaei

Hi @sonali-kumari1 -

Thank you for your attention. You could use the following code:

import tensorflow as tf
import numpy as np
import time

def build_model(seq_len=1_000_000, channels=3, filters=128):
    return tf.keras.Sequential([
        tf.keras.layers.Conv1D(filters, kernel_size=7, padding='same', activation='relu', input_shape=(seq_len, channels)),
        tf.keras.layers.Conv1D(filters, kernel_size=7, padding='same', activation='relu'),
        tf.keras.layers.Conv1D(filters, kernel_size=7, padding='same', activation='relu'),
    ])


def generate_sequences(n, seq_len=1_000_000, channels=3, seed=42):
    rng = np.random.default_rng(seed)
    return rng.standard_normal(size=(n, seq_len, channels)).astype(np.float32)

if __name__ == "__main__":
    # Parameters
    n_samples = 50
    seq_len = 1_000_000
    channels = 3

    # Model and data
    model = build_model(seq_len, channels)
    data = generate_sequences(n_samples, seq_len, channels)

    # tf.data.Dataset pipeline
    dataset = tf.data.Dataset.from_tensor_slices(data).batch(1)

    # warm up
    model.predict_on_batch(next(iter(dataset)))
    # Benchmark model.predict_on_batch
    start = time.time()
    for batch in dataset:
        model.predict_on_batch(batch)
    end = time.time()
    print(f"model.predict_on_batch (manual batch): {end - start:.2f} seconds", flush=True)

    # Benchmark model.predict
    model.predict(dataset)

Here's the output:

model.predict_on_batch (manual batch): 7.00 seconds
50/50 ━━━━━━━━━━━━━━━━━━━━ 4s 84ms/step

If you increase the number of samples (e.g., to 200), model.predict(...) causes an OOM error on an NVIDIA A100 80GB GPU because the predictions are accumulated in memory.
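
For reference, one way to keep memory flat today is to compile the forward pass with tf.function and stream each batch to disk as it is produced. This is only a sketch (predict_step, predict_to_disk, the output directory, and the .npy sink are placeholders), and it is essentially the manual loop whose throughput falls short of model.predict, which is why the requested argument would help:

import numpy as np
import tensorflow as tf

@tf.function(reduce_retracing=True)
def predict_step(model, batch):
    # Compiled forward pass; traced once for a fixed batch shape.
    return model(batch, training=False)

def predict_to_disk(model, dataset, out_dir):
    # Stream predictions to disk so nothing is accumulated in host or GPU memory.
    for i, batch in enumerate(dataset):
        preds = predict_step(model, batch)
        np.save(f"{out_dir}/batch_{i:06d}.npy", preds.numpy())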