
I have a trained LSTM model. The model was trained on a GPU (on Google Colaboratory). I have to save the model for inference, which I will run on a CPU. Once trained, I saved the model checkpoint as follows:

torch.save({'model_state_dict': model.state_dict()},'lstmmodelgpu.tar')

And for inference, I loaded the model as:

# model definition
vocab_size = len(vocab_to_int)+1 
output_size = 1
embedding_dim = 300
hidden_dim = 256
n_layers = 2

model = SentimentLSTM(vocab_size, output_size, embedding_dim, hidden_dim, n_layers)

# loading model
device = torch.device('cpu')
checkpoint = torch.load('lstmmodelgpu.tar', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

But, it is raising the following error:

model.load_state_dict(checkpoint['model_state_dict'])
  File "workspace/envs/envdeeplearning/lib/python3.5/site-packages/torch/nn/modules/module.py", line 719, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for SentimentLSTM:
    Missing key(s) in state_dict: "embedding.weight". 
    Unexpected key(s) in state_dict: "encoder.weight".

Is there anything I missed while saving the checkpoint?

  • Are you using DataParallel? Commented Jan 19, 2019 at 5:03
  • @harshit_k, No, I am not using DataParallel. I have followed the tutorial : pytorch.org/tutorials/beginner/… Commented Jan 19, 2019 at 8:41
  • I don't know why the error occurs, but you may still try the solution given in discuss.pytorch.org/t/… or use model.module.load_state_dict(checkpoint['model_state_dict']). But I'm not certain either of these would work. Commented Jan 19, 2019 at 9:20

1 Answer


There are two things to consider here.

  1. You mentioned that you train your model on a GPU and run inference on a CPU, so you need to pass the map_location parameter to torch.load with torch.device('cpu'), which your loading code already does.

  2. There is a mismatch of state_dict keys (indicated in your error message), which may be caused by keys missing from the state_dict you are loading, or by it containing more keys than the model you are currently using has. To get past this, pass strict=False to load_state_dict; that makes the method ignore the mismatched keys. A minimal sketch is shown after this list.
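For example, a minimal sketch reusing the names from your question (SentimentLSTM, lstmmodelgpu.tar, and the hyperparameters you defined); the only change to your loading code is the strict=False argument:

import torch

device = torch.device('cpu')
model = SentimentLSTM(vocab_size, output_size, embedding_dim, hidden_dim, n_layers)

# map_location remaps tensors saved on the GPU onto the CPU while loading
checkpoint = torch.load('lstmmodelgpu.tar', map_location=device)

# strict=False tells load_state_dict to ignore keys that do not match between
# the checkpoint and the model (here "encoder.weight" vs "embedding.weight")
model.load_state_dict(checkpoint['model_state_dict'], strict=False)
model.eval()

Keep in mind that a parameter whose key is skipped this way keeps its freshly initialized weights, so check that the ignored keys really are ones your model can do without.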

Side note: use the .pt or .pth extension for checkpoint files, as that is the convention.
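For example (the file name here is just illustrative):

torch.save({'model_state_dict': model.state_dict()}, 'lstm_model.pt')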
