One of the most flexible ways to train machine learning models is to feed the training data to the fit function through a Python generator. This approach has several advantages. It lets you preprocess the data in a customized way in every training loop (e.g. data augmentation). It handles the batch size and the reshuffling of the data between epochs automatically. Finally, it lets you load the training dataset gradually when it is too large to fit into RAM all at once.
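A minimal sketch of the last two points, assuming a hypothetical plain-text file with one sample per line. The function file_batch_generator and its preprocess parameter are illustrative names, not part of any library. The generator reads the file lazily and applies preprocess to every sample. It keeps at most one batch in memory and reopens the file at the start of each new pass.

```python
def file_batch_generator(path, batch_size, preprocess=lambda s: s):
    """Yield lists of preprocessed samples, reading `path` lazily.

    Assumes (hypothetically) one sample per line; only the current batch
    is held in memory, and the file is reopened for every new pass.
    """
    while True:
        batch, n_read = [], 0
        with open(path) as f:
            for line in f:
                n_read += 1
                batch.append(preprocess(line.rstrip("\n")))
                if len(batch) == batch_size:
                    yield batch
                    batch = []
        if n_read == 0:
            return  # empty file: stop instead of looping forever
        if batch:
            yield batch  # last, possibly smaller, batch of this pass
```

The preprocess hook is where per-sample work such as augmentation or parsing would go. Because it runs each time a sample is read, random transformations produce different results in every epoch.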

Create a data generator using a function

This snippet can be used as a blueprint to create the most suitable data generator for a given project.

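import random
import numpy as np

def data_generator(batch_size, data_x, data_y, shuffle=True):
    """Yield (batch_x, batch_y) tuples forever, reshuffling after each full pass."""
    data_size = len(data_x)
    index_list = [*range(data_size)]
    if shuffle:
        random.shuffle(index_list)
    index = 0
    while True:
        batch_x, batch_y = [], []

        for _ in range(batch_size):
            # Start a new pass (epoch) once every sample has been used
            if index >= data_size:
                index = 0
                if shuffle:
                    random.shuffle(index_list)
            # Use the same shuffled index for x and y so the pairs stay aligned
            batch_x.append(data_x[index_list[index]])
            batch_y.append(data_y[index_list[index]])
            index += 1
        yield batch_x, batch_y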
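If the data are already NumPy arrays, the same idea can be sketched with a random permutation and fancy indexing. In this variant, array_batch_generator and its seed parameter are illustrative names. It returns arrays instead of lists and draws a fresh permutation at every epoch. Unlike data_generator, it drops the last incomplete batch of each epoch instead of wrapping into the next one.

```python
import numpy as np

def array_batch_generator(x, y, batch_size, shuffle=True, seed=None):
    """Yield (x_batch, y_batch) NumPy arrays forever, reshuffling every epoch.

    Each epoch yields len(x) // batch_size full batches; leftover samples
    are skipped for that epoch (they reappear in later permutations).
    """
    x = np.asarray(x)
    y = np.asarray(y)
    n = len(x)
    if batch_size > n:
        raise ValueError("batch_size must not exceed the number of samples")
    rng = np.random.default_rng(seed)
    while True:
        order = rng.permutation(n) if shuffle else np.arange(n)
        for start in range(0, n - batch_size + 1, batch_size):
            idx = order[start:start + batch_size]
            yield x[idx], y[idx]  # same indexes keep x and y aligned
```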


To show how the generator works, we can create a simple dataset in which X and Y are ordered lists of the letters of the alphabet, in lowercase and uppercase respectively.

X = [chr(i) for i in np.arange(ord('a'), ord('a')+26)] # lowercase letters
Y = [chr(i) for i in np.arange(ord('A'), ord('A')+26)] # uppercase letters

We can create the data generator with a batch_size of 3:

my_data_gen = data_generator(batch_size=3, data_x=X, data_y=Y, shuffle=True)

Finally, the generator can be fed to a fit method to train a machine learning model. To see how it works, we can draw a few batches with a simple for loop:

for i in range(5):
    x, y = next(my_data_gen)
    print(f'Batch number {i}:  x = {x}   y = {y}')

Since shuffle=True, the order changes from run to run. One possible output is:

Batch number 0:  x = ['z', 'k', 'n']   y = ['Z', 'K', 'N']
Batch number 1:  x = ['p', 'e', 'j']   y = ['P', 'E', 'J']
Batch number 2:  x = ['x', 'i', 'g']   y = ['X', 'I', 'G']
Batch number 3:  x = ['d', 'v', 'f']   y = ['D', 'V', 'F']
Batch number 4:  x = ['m', 'a', 'q']   y = ['M', 'A', 'Q']
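Because data_generator never stops, the training loop has to be told how many batches make up one epoch. In Keras this is the steps_per_epoch argument of model.fit. The sketch below mimics that behaviour in plain Python. Here toy_fit and cycling_batches are hypothetical stand-ins, not library functions, and cycling_batches skips shuffling to keep the result predictable.

```python
import math
from itertools import islice

def cycling_batches(data, batch_size):
    """Hypothetical infinite generator: consecutive batches, wrapping around."""
    i = 0
    while True:
        yield [data[(i + k) % len(data)] for k in range(batch_size)]
        i = (i + batch_size) % len(data)

def toy_fit(generator, steps_per_epoch, epochs):
    """Hypothetical fit loop: pulls exactly steps_per_epoch batches per epoch."""
    history = []
    for epoch in range(epochs):
        # islice stops after steps_per_epoch items, but the generator keeps
        # its position, so the next epoch continues where this one stopped
        for batch in islice(generator, steps_per_epoch):
            history.append((epoch, batch))  # a real loop would update weights here
    return history

letters = [chr(c) for c in range(ord('a'), ord('a') + 26)]
steps = math.ceil(len(letters) / 3)  # 9 batches cover all 26 samples
history = toy_fit(cycling_batches(letters, 3), steps_per_epoch=steps, epochs=2)
```

With Keras and a compiled model on numeric data, the equivalent call would be model.fit(my_data_gen, steps_per_epoch=math.ceil(len(X) / 3), epochs=...). Rounding up means 9 × 3 = 27 samples are drawn per epoch, so one sample spills over from the next pass. Rounding down would skip samples instead.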

Data generator template for TensorFlow

Fill in the ... parts:

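import numpy as np
import tensorflow as tf

class DataGenerator(tf.keras.utils.Sequence):
    def __init__(self, data, batch_size=32, shuffle=True):
        super().__init__()
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.data_len = ...  # number of samples in data
        self.on_epoch_end()  # create self.indexes before the first batch

    def get_batch_shape(self):
        return (self.batch_size, ...)

    def __len__(self):
        # Number of full batches per epoch (a trailing partial batch is dropped)
        return int(np.floor(self.data_len / self.batch_size))

    def __getitem__(self, index):
        # Sample indexes that belong to batch number `index`
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        X = ...
        y = ...
        return X, y

    def on_epoch_end(self):
        # Called at the end of every epoch: rebuild and reshuffle the sample order
        self.indexes = np.arange(self.data_len)
        if self.shuffle:
            np.random.shuffle(self.indexes)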
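To see the template's indexing logic without installing TensorFlow, here is a sketch of the same three-method protocol (__len__, __getitem__, on_epoch_end) as a plain Python class. LetterSequence and its use of the letter dataset are illustrative assumptions. The class does not subclass tf.keras.utils.Sequence, so it cannot be passed to model.fit as written.

```python
import random

class LetterSequence:
    """Plain-Python stand-in for the Sequence template (no TensorFlow needed).

    Implements the same methods the template defines; names and data are
    illustrative only.
    """

    def __init__(self, data_x, data_y, batch_size=3, shuffle=True):
        self.data_x, self.data_y = data_x, data_y
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.data_len = len(data_x)
        self.on_epoch_end()  # build self.indexes before the first batch

    def __len__(self):
        # Number of full batches; the remainder is dropped, as in the template
        return self.data_len // self.batch_size

    def __getitem__(self, index):
        ids = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        return [self.data_x[i] for i in ids], [self.data_y[i] for i in ids]

    def on_epoch_end(self):
        # New (optionally shuffled) sample order for the next epoch
        self.indexes = list(range(self.data_len))
        if self.shuffle:
            random.shuffle(self.indexes)
```

During an epoch, Keras requests batches 0 to len(seq) - 1 and calls on_epoch_end between epochs. Because the length is known, a real Sequence subclass can be passed to model.fit without steps_per_epoch.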