4

The following code runs a straightforward Sequential Keras model on the MNIST data packaged with Keras.

When I run the following piece of code I get an exception. The code is readily reproducible.

import tensorflow as tf

class myCallback(tf.keras.callbacks.Callback):
      def on_epoch_end(self, epoch, logs={}):
        if(logs.get('acc')>0.99):
          print("\nReached 99% accuracy so cancelling training!")
          self.model.stop_training = True

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

callbacks = myCallback()

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(512, activation=tf.nn.relu),
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])

The exception is:

Epoch 1/10
59296/60000 [============================>.] - ETA: 0s - loss: 0.2005 - accuracy: 0.9400

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-26-f5e673b24d24> in <module>()
     23               metrics=['accuracy'])
     24 
---> 25 model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])

C:\Program Files (x86)\Microsoft Visual Studio\Shared\Anaconda3_64\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    871           validation_steps=validation_steps,
    872           validation_freq=validation_freq,
--> 873           steps_name='steps_per_epoch')
    874 
    875   def evaluate(self,

C:\Program Files (x86)\Microsoft Visual Studio\Shared\Anaconda3_64\lib\site-packages\tensorflow\python\keras\engine\training_arrays.py in model_iteration(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_freq, mode, validation_in_fit, prepared_feed_values_from_dataset, steps_name, **kwargs)
    406     if mode == ModeKeys.TRAIN:
    407       # Epochs only apply to `fit`.
--> 408       callbacks.on_epoch_end(epoch, epoch_logs)
    409     progbar.on_epoch_end(epoch, epoch_logs)
    410 

C:\Program Files (x86)\Microsoft Visual Studio\Shared\Anaconda3_64\lib\site-packages\tensorflow\python\keras\callbacks.py in on_epoch_end(self, epoch, logs)
    288     logs = logs or {}
    289     for callback in self.callbacks:
--> 290       callback.on_epoch_end(epoch, logs)
    291 
    292   def on_train_batch_begin(self, batch, logs=None):

<ipython-input-26-f5e673b24d24> in on_epoch_end(self, epoch, logs)
      3 class myCallback(tf.keras.callbacks.Callback):
      4       def on_epoch_end(self, epoch, logs={}):
----> 5         if(logs.get('acc')>0.99):
      6           print("\nReached 99% accuracy so cancelling training!")
      7           self.model.stop_training = True

TypeError: '>' not supported between instances of 'NoneType' and 'float'
  • 1
    Have you tried if(logs.get('accuracy')>0.99):? Commented Jun 3, 2019 at 8:25
  • 2
When creating a callback, if we need an accuracy threshold for training, earlier TF versions have logs.get('acc'), but in this version we need to use logs.get('accuracy') for it to work. There wasn't any documentation regarding this change. Commented Jun 21, 2020 at 16:26
  • The error is because the key passed to logs.get('acc') must match the metric name in model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) (a quick way to see which keys your version logs is sketched after these comments). Commented Aug 18, 2020 at 1:55
  • 1
    @amalik2205 your comment is very on point! +1. This also happened to me! Commented Oct 6, 2021 at 22:00
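
If you are unsure which key your installed TensorFlow version logs, a quick diagnostic (a minimal sketch, not from the question; KeyInspector is a made-up name) is to print the keys of the logs dictionary from a callback and pass an instance to model.fit via callbacks=[KeyInspector()]:

import tensorflow as tf

class KeyInspector(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Print every metric key Keras logged for this epoch, e.g.
        # ['acc', 'loss'] on TF 1.x or ['accuracy', 'loss'] on TF 2.x.
        print(sorted((logs or {}).keys()))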

9 Answers

7

In the model.compile call you defined metrics=['accuracy'], so Keras logs the metric under the key 'accuracy'. You need to use that same key with logs.get, i.e. logs.get('accuracy').
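
Applied to the callback from the question, that change looks like this (a minimal sketch assuming TF 2.x, where the metric is logged as 'accuracy'; the 0.99 threshold is kept from the question):

import tensorflow as tf

class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # 'accuracy' matches the name given in metrics=['accuracy'] in compile().
        if logs.get('accuracy', 0) > 0.99:
            print("\nReached 99% accuracy so cancelling training!")
            self.model.stop_training = True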


2

In a Jupyter notebook I had to use "acc", but in Google Colab "accuracy" instead. I guess it depends on the TensorFlow version installed.

1 Comment

Same here; not sure why that is the case, but this solution worked for me. Maybe there are different versions installed in Colab compared to the notebook.
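
If the same notebook has to run under both versions, one option (a sketch, not part of this answer) is to fall back from one key to the other:

import tensorflow as tf

class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # TF 2.x logs the metric as 'accuracy', TF 1.x as 'acc'; try both keys.
        acc = logs.get('accuracy', logs.get('acc'))
        if acc is not None and acc > 0.99:
            print("\nReached 99% accuracy so cancelling training!")
            self.model.stop_training = True
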
1

It's just that with the upgrade of TensorFlow to version 2.x the dictionary key 'acc' has been changed to 'accuracy', so replacing line 5 as follows should do the trick:

if(logs.get('accuracy')>0.99):


1

Just change logs.get('accuracy') --> logs.get('acc'). It should work fine!


1

I had the same problem. I changed it to "acc" and it worked like a charm. I made the following changes.

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])

And in the callback,

class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs={}):
        if(logs.get("acc") >= 0.99):
            print("Reached 99% accuracy so cancelling training!")
            self.model.stop_training = True


0

I think it can come from the way you call your callback:

If your callback class is

class myCallback(tf.keras.callbacks.Callback):
...

it should be passed to fit like this:

model.fit(x_train, y_train, epochs=10, callbacks=[myCallback()])


0

Probably you are using TensorFlow 1.x, so you may try if(logs.get('acc')>0.998) and metrics=['acc'].


-2

The problem is logs.get('acc')>0.99. On your end logs.get('acc') is None for some reason.

Just execute:

None > 0.99

and you will get the same error. You probably migrated your code from Python 2, where this comparison would actually work :).

You can simply tweak this with

if(logs.get('acc') is None): # in this case you cannot compare...

Or you can use try: ... except: blocks.
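
Put together, a guarded version of the callback could look like this (a sketch using the first approach; it keeps the 'acc' key and 0.99 threshold from the question):

import tensorflow as tf

class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        acc = (logs or {}).get('acc')
        if acc is None:
            # Key not present in this TF version, so there is nothing to compare.
            return
        if acc > 0.99:
            print("\nReached 99% accuracy so cancelling training!")
            self.model.stop_training = True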

BTW, the same code works just fine on my end.


-2

For some reason I used 'acc' in the callback class, with ['accuracy'] in the metrics, and it worked.
