5

I am using PyTorch Lightning version 1.4.0 and have defined the following class for the dataset:

class CustomTrainDataset(Dataset):
    '''
    Custom PyTorch Dataset for training
    
    Args:
        data (pd.DataFrame) - DF containing product info (and maybe also ratings)
        all_itemIds (list) - Python3 list containing all Item IDs
    '''
    
    def __init__(self, data, all_orderIds):
        self.users, self.items, self.labels = self.get_dataset(data, all_orderIds)
    
    def __len__(self):
        return len(self.users)
  
    def __getitem__(self, idx):
        return self.users[idx], self.items[idx], self.labels[idx]
    
    def get_dataset(self, data, all_orderIds):
        users, items, labels = [], [], []
        user_item_set = set(zip(train_ratings['CustomerID'], train_ratings['ItemCode']))

        num_negatives = 7
        for u, i in user_item_set:
            users.append(u)
            items.append(i)
            labels.append(1)
            for _ in range(num_negatives):
                negative_item = np.random.choice(all_itemIds)
                while (u, negative_item) in user_item_set:
                    negative_item = np.random.choice(all_itemIds)
                users.append(u)
                items.append(negative_item)
                labels.append(0)

        return torch.tensor(users), torch.tensor(items), torch.tensor(labels)

followed by the PL class:

class NCF(pl.LightningModule):
    '''
    Neural Collaborative Filtering (NCF)
    
    Args:
        num_users (int): Number of unique users
        num_items (int): Number of unique items
        data (pd.DataFrame): Dataframe containing the food ratings for training
        all_orderIds (list): List containing all orderIds (train + test)
    '''
    
    def __init__(self, num_users, num_items, data, all_itemIds):
    # def __init__(self, num_users, num_items, ratings, all_movieIds):
        super().__init__()
        self.user_embedding = nn.Embedding(num_embeddings = num_users, embedding_dim = 8)
        # self.user_embedding = nn.Embedding(num_embeddings = num_users, embedding_dim = 10)
        self.item_embedding = nn.Embedding(num_embeddings = num_items, embedding_dim = 8)
        # self.item_embedding = nn.Embedding(num_embeddings = num_items, embedding_dim = 10)
        self.fc1 = nn.Linear(in_features = 16, out_features = 64)
        # self.fc1 = nn.Linear(in_features = 20, out_features = 64)
        self.fc2 = nn.Linear(in_features = 64, out_features = 64)
        self.fc3 = nn.Linear(in_features = 64, out_features = 32)
        self.output = nn.Linear(in_features = 32, out_features = 1)
        self.data = data
        # self.ratings = ratings
        # self.all_movieIds = all_movieIds
        self.all_orderIds = all_orderIds
        
    def forward(self, user_input, item_input):
        
        # Pass through embedding layers
        user_embedded = self.user_embedding(user_input)
        item_embedded = self.item_embedding(item_input)

        # Concat the two embedding layers
        vector = torch.cat([user_embedded, item_embedded], dim = -1)

        # Pass through dense layer
        vector = nn.ReLU()(self.fc1(vector))
        vector = nn.ReLU()(self.fc2(vector))
        vector = nn.ReLU()(self.fc3(vector))

        # Output layer
        pred = nn.Sigmoid()(self.output(vector))

        return pred
    
    def training_step(self, batch, batch_idx):
        user_input, item_input, labels = batch
        predicted_labels = self(user_input, item_input)
        loss = nn.BCELoss()(predicted_labels, labels.view(-1, 1).float())
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

    def train_dataloader(self):
        return DataLoader(
            ChupsTrainDataset(
                self.data, self.all_orderIds
            ),
            batch_size = 32, num_workers = 2
            # Google Colab's suggested max number of worker in current
            # system is 2 and not 4.
        )
print(f"num_users = {num_users}, num_items = {num_items} & all_itemIds = {len(all_itemIds)}")
# num_users = 12958, num_items = 511238 & all_itemIds = 9114

# Initialize NCF model-
model = NCF(num_users, num_items, train_ratings, all_itemIds)

trainer = pl.Trainer(
    max_epochs = 75, gpus = 1,
    # max_epochs = 5,
    reload_dataloaders_every_n_epochs = True,
    # reload_dataloaders_every_epoch = True,   # deprecated!
    progress_bar_refresh_rate = 50,
    logger = False, checkpoint_callback = False)

trainer.fit(model)

# Save trained model as a checkpoint-
trainer.save_checkpoint("NCF_Trained.ckpt")

To load the saved checkpoint, I have tried:

trained_model = NCF.load_from_checkpoint(
    "NCF_Trained.ckpt", num_users = num_users,
    num_items = train_ratings, data = train_ratings,
    all_itemIds = all_itemIds)
trained_model = NCF(num_users, num_items, train_ratings, all_orderIds).load_from_checkpoint(checkpoint_path = "NCF_Trained.ckpt")

But these don't seem to work. How do I load this saved checkpoint?

3
  • Can you please clarify what you mean by "dont' seem to work". What exactly are you trying to do after .load_from_checkpoint(...) ? Commented Aug 8, 2021 at 22:35
  • Is this solved ? @Arun Commented Sep 1, 2021 at 18:32
  • Not yet @AyanDas Commented Sep 6, 2021 at 14:41

3 Answers 3

2

As shown in here, load_from_checkpoint is a primary way to load weights in pytorch-lightning and it automatically load hyperparameter used in training. So you do not need to pass params except for overwriting existing ones. My suggestion is to try trained_model = NCF.load_from_checkpoint("NCF_Trained.ckpt")

Sign up to request clarification or add additional context in comments.

Comments

2

add a line in your init method:

self.save_hyperparameters(logger=False)

Then call

trained_model = NCF.load_from_checkpoint("NCF_Trained.ckpt")

Comments

0

In my case it was crucial to set the model into the evaluation mode via model.eval(). Otherwise it would produce wrong results.

Comments

Your Answer

By clicking “Post Your Answer”, you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.