
I have trained a TensorFlow model to predict the next word for an input text and saved it as a .h5 file.

I can use that model in another Python script to predict the next word as follows:

import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import load_model

model = load_model('model.h5')
model.compile(
    loss = "categorical_crossentropy",
    optimizer = "adam",
    metrics = ["accuracy"]
)

# Rebuild the tokenizer on the same corpus the model was trained on
data = open("dataset.txt").read()
corpus = data.lower().split("\n")
tokenizer = Tokenizer()
tokenizer.fit_on_texts(corpus)

seed_text = input()

# Tokenize and pad the seed text to the model's input length (max_sequence_len - 1)
sequence_text = tokenizer.texts_to_sequences([seed_text])[0]
padded_sequence = np.array(pad_sequences([sequence_text], maxlen = 11 - 1))
predicted = np.argmax(model.predict(padded_sequence))
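
For reference, the index from argmax can be mapped back to an actual word through the tokenizer's index_word dictionary (a small sketch, not part of the original script):

# Look up the predicted index in the tokenizer's reverse vocabulary
predicted_word = tokenizer.index_word.get(int(predicted), "")
print(predicted_word)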

Is there a way to use that model directly inside Flutter, where I can take input from a TextField() and, by pressing a button, display the predicted word?

2 Answers


Steps

  1. Convert the model into a .tflite model. (Since the question starts from a .h5 file, a sketch of that conversion follows the snippet below.)
# https://www.tensorflow.org/lite/convert/#convert_a_savedmodel_recommended_

import tensorflow as tf

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) # path to the SavedModel directory
tflite_model = converter.convert()

# Save the model.
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)
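
Since the question saves the model as a .h5 file rather than a SavedModel directory, here is a minimal sketch of the equivalent conversion for a loaded Keras model (assuming the file is named model.h5):

import tensorflow as tf

# Load the trained Keras model from the .h5 file
model = tf.keras.models.load_model('model.h5')

# Convert the in-memory Keras model to TensorFlow Lite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Save the converted model
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)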
  2. Add the .tflite model to the app directory. I usually put the model in an assets/ directory (remember to also declare it under the flutter: assets: section of pubspec.yaml so it gets bundled with the app).
android/
assets/
    model.tflite
ios/
lib/
  3. Add tflite as a dependency to pubspec.yaml:
dependencies:
  flutter:
    sdk: flutter
  tflite: ^1.0.5
  .
  .
  4. Run inference in your Dart script. For example, the following snippet shows how to run inference on an image, where labels.txt is a text file containing the classes:
import 'package:tflite/tflite.dart';
.
.
.

class _MyAppState extends State<MyApp> {
  . . .
  @override
  void initState() {
    super.initState();
    _loading = true;

    loadModel().then((value) {
      setState(() {
        _loading = false;
      });
    });
  }

  // Run the TFLite model on an image file and store the results
  Future<void> classifyImage(File image) async {
    var output = await Tflite.runModelOnImage(
      path: image.path,
      numResults: 2,
      threshold: 0.5,
      imageMean: 127.5,
      imageStd: 127.5,
    );
    setState(() {
      _loading = false;
      _outputs = output;
    });
  }

  // Load the .tflite model and its labels from the assets directory
  Future<void> loadModel() async {
    await Tflite.loadModel(
      model: "assets/model_unquant.tflite",
      labels: "assets/labels.txt",
    );
  }

  @override
  void dispose() {
    // Release the TFLite resources when the widget is removed
    Tflite.close();
    super.dispose();
  }
 . . .
}


SideNote

The tflite plugin doesn't support text classification AFAIK. If you specifically want to do text classification, I'd recommend the tflite_flutter plugin instead. Below is a link to an article that uses that plugin for text classification:

Text Classification using TensorFlow Lite Plugin for Flutter
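
Whichever plugin you use, a next-word model consumes token indices rather than raw text, so the Flutter app also needs the tokenizer's vocabulary. A minimal sketch of exporting it from Python (the word_index.json file name and the idea of loading it in the app are assumptions for illustration, not part of either plugin's API):

import json
from tensorflow.keras.preprocessing.text import Tokenizer

# Rebuild the tokenizer exactly as it was fitted for training
corpus = open("dataset.txt").read().lower().split("\n")
tokenizer = Tokenizer()
tokenizer.fit_on_texts(corpus)

# Export the word -> index mapping so the app can turn input text
# into the index sequence expected by the .tflite model
with open("word_index.json", "w") as f:
    json.dump(tokenizer.word_index, f)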


1 Comment

assets/model_unquant.tflite — should it be assets/model.tflite in the loadModel method?

You cannot use a .h5 file directly in Flutter. You will need to either convert it into a .tflite file and use that, or create a REST API.

Converting it to a .tflite file is the easier option. You can check the following article for more details: https://medium.com/analytics-vidhya/run-cnn-model-in-flutter-10c944cadcba

If you want to create a REST API, check out this article: https://medium.com/analytics-vidhya/deploy-ml-models-using-flask-as-rest-api-and-access-via-flutter-app-7ce63d5c1f3b
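
If you go the REST API route, the idea is to keep the Keras model on a server and have the Flutter app send the seed text over HTTP. A minimal Flask sketch (the /predict route, port, and JSON field names are illustrative assumptions, not taken from the article):

import numpy as np
from flask import Flask, jsonify, request
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

app = Flask(__name__)

# Load the model and rebuild the tokenizer once at startup
model = load_model("model.h5")
corpus = open("dataset.txt").read().lower().split("\n")
tokenizer = Tokenizer()
tokenizer.fit_on_texts(corpus)

@app.route("/predict", methods=["POST"])
def predict():
    seed_text = request.get_json()["text"]
    sequence = tokenizer.texts_to_sequences([seed_text])[0]
    padded = np.array(pad_sequences([sequence], maxlen=11 - 1))
    index = int(np.argmax(model.predict(padded)))
    # Map the predicted index back to the corresponding word
    word = tokenizer.index_word.get(index, "")
    return jsonify({"word": word})

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)

The Flutter app can then POST the contents of the TextField to this endpoint (for example with the http package) and display the returned word.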

