Thursday, November 26, 2020

Best Practices with TensorFlow: Test All Your Inputs Early

TensorFlow model training can take a long time. Depending on the machine, even the first epoch can take many hours to run. Arguments like validation_data are only used once the first epoch finishes, so incorrectly formed input is detected many hours later, deep in the stack, with incomprehensible errors. By this time, you might have forgotten what shapes and data types the inputs are expected to have. You want to detect failures early, when you call model.fit(), rather than a day later.

One easy way to avoid this is to wrap all model.fit() calls in a single function, which first tests itself with a smaller input. Here's the idea. If you have training data (X_train and y_train) and validation data (X_valid and y_valid), you take the first 10 observations of each and call the same training function on them before the real run.

This run will be fast, because you are training a model with a tiny set of observations, and the validation is done with a tiny set of values too.

Because the test run goes through the same function, any change to the model.fit() line is reflected both in the test_shapes run and in the real run. The test run does train the model slightly, but when you train from scratch this is harmless: the real run with the full data and the full set of epochs will change the weights far more. If you are using a pre-trained model and want its weights left untouched, clone the model and do the test run on the clone.
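For the pre-trained case, a minimal sketch of the clone-then-test idea might look like the following. The helper name smoke_test_on_clone and its loss and optimizer parameters are hypothetical; it assumes NumPy inputs and that you pass the same loss you compiled the real model with. tf.keras.models.clone_model rebuilds the architecture with fresh weights, so the sketch copies the original weights across before fitting.

```python
import numpy as np
import tensorflow as tf


def smoke_test_on_clone(model, X_train, y_train, X_valid, y_valid,
                        loss, optimizer="sgd", n=10):
    """Fit a throwaway copy of `model` on the first n rows (hypothetical helper)."""
    # clone_model copies the architecture only; the clone gets new weights.
    scratch = tf.keras.models.clone_model(model)
    # Copy the weights so the clone starts from the same state as the original.
    scratch.set_weights(model.get_weights())
    # The clone is not compiled, so compile it here. Assumption: the caller
    # passes the same loss that the real model uses.
    scratch.compile(optimizer=optimizer, loss=loss)
    # One short epoch is enough to push both training and validation
    # data through the model and surface shape or dtype errors.
    scratch.fit(X_train[:n], y_train[:n],
                validation_data=(X_valid[:n], y_valid[:n]),
                epochs=1, batch_size=5, verbose=0)
```

Only the clone is trained, so the weights of the original model are the same after the call as before it.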

If you are using data generators, the same mechanism works. Create a small tf.data dataset with a few observations from the training data, and another with a few observations from the validation data. Pass those to your model training function first.

def fit_e9_model(model, X_train, y_train,
                 X_valid, y_valid, epochs,
                 batch_size=32, verbose=0,
                 test_shapes=True):

    # A bare model.fit() call can fail after a day, when the validation
    # data turns out to be incorrectly shaped. Failures should be early.

    # The best way to guard against it is to run a small fit run with
    # a tiny data size, and a tiny validation data size to ensure that
    # the data is correctly shaped.

    if test_shapes:
        print("Testing with the first 10 elements of the input")
        # Slice only the first axis, so this works for inputs of any rank.
        X_small = X_train[:10]
        y_small = y_train[:10]
        X_valid_small = X_valid[:10]
        y_valid_small = y_valid[:10]
        # Call ourselves again with a smaller input. This confirms
        # that the two runs go through the same .fit() call, and
        # that the input is correctly shaped for the full run too.
        fit_e9_model(model, X_small, y_small,
                     X_valid_small, y_valid_small,
                     epochs=epochs, verbose=verbose,
                     batch_size=5,
                     test_shapes=False)

    # If that worked, then do a full run.
    history_conv = model.fit(x=X_train, y=y_train, batch_size=batch_size,
                             validation_data=(X_valid, y_valid),
                             epochs=epochs, verbose=verbose)
    return history_conv



history = fit_e9_model(model, X_train, y_train, X_valid, y_valid, epochs=10)
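
For the data generator case mentioned earlier, the same pattern could be sketched with tf.data as follows. The function name fit_with_datasets is hypothetical, and the sketch assumes both datasets are already batched, so Dataset.take(2) keeps the first two batches of each. Unlike the function above, the test run here uses a single epoch, which is a choice made for this sketch.

```python
import numpy as np
import tensorflow as tf


def fit_with_datasets(model, train_ds, valid_ds, epochs,
                      verbose=0, test_shapes=True):
    """Fit on tf.data datasets, with an optional tiny test run first."""
    if test_shapes:
        print("Testing with the first 2 batches of each dataset")
        # take(2) yields a new dataset holding only the first two batches
        # (assumption: train_ds and valid_ds are already batched).
        fit_with_datasets(model, train_ds.take(2), valid_ds.take(2),
                          epochs=1, verbose=verbose,
                          test_shapes=False)

    # If that worked, then do a full run through the same fit() call.
    return model.fit(train_ds, validation_data=valid_ds,
                     epochs=epochs, verbose=verbose)
```

It is called the same way as the array version, for example history = fit_with_datasets(model, train_ds, valid_ds, epochs=10), where train_ds and valid_ds are your own batched datasets.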