LeoSouquet/Bert4Rec_Pytorch


Implementation of BERT4REC

This repo contains a PyTorch implementation of the BERT4Rec model.

Original Paper: BERT4Rec: Sequential Recommendation with BERT (Sun et al.)

Install

Clone the repo and install the dependencies from requirements.txt:

git clone https://github.com/LeoCyclope/BERT4Rec  # clone
cd BERT4Rec
pip install -r requirements.txt  # install

Google Colab

This repo can be tested using the following notebook. As this repo is private, you will need to configure a token to clone the repo:

Google Colab Notebook

Inference

The script download_weights.sh only downloads the weights locally; it does not run inference.

Requirements: the AWS CLI is required to download the weights (https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html).

bash download_weights.sh

To perform inference, run:

python inference.py

You can change the model path by editing the init_weights variable in the file common_utils/metadata.py.
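The override above is just a variable assignment. A sketch of what the relevant excerpt of common_utils/metadata.py is assumed to look like (the surrounding contents and the example path are illustrative, not the repo's actual values):

```python
# common_utils/metadata.py (excerpt -- assumed shape)
# inference.py is expected to read this variable to locate the checkpoint.
# Point it at your own local weights file to load a different model.
init_weights = "weights/bert4rec_final.pth"  # hypothetical path
```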

Disclaimer

For the purpose of this demo, the model is stored on S3 with public access. In production, models should have restricted access, e.g. through IAM policies, ACLs, or a VPN.

Training

Data

Data Preparation

For the purpose of this demo, the data are TensorFlow records (.tfrecord files) in the "Data" folder.

You can create your own Dataset class to integrate your own data.
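A minimal sketch of such a custom Dataset class, assuming sequential-recommendation-style data where each user has an ordered list of movie indices (the class name, padding scheme, and input/target split here are illustrative, not the repo's actual interface):

```python
import torch
from torch.utils.data import Dataset


class SequenceDataset(Dataset):
    """Custom dataset sketch: each item is a fixed-length sequence of
    movie indices, left-padded with `pad_id`, with the final item in
    the sequence used as the prediction target."""

    def __init__(self, user_sequences, max_len=50, pad_id=0):
        self.sequences = user_sequences  # list of lists of item ids
        self.max_len = max_len
        self.pad_id = pad_id

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx][-self.max_len:]
        inputs, target = seq[:-1], seq[-1]
        # Left-pad so every input tensor has length max_len - 1.
        padded = [self.pad_id] * (self.max_len - 1 - len(inputs)) + list(inputs)
        return torch.tensor(padded, dtype=torch.long), torch.tensor(target)


ds = SequenceDataset([[5, 9, 12, 3], [7, 2]], max_len=5)
x, y = ds[0]  # x: padded input sequence, y: target item id
```

A class like this plugs directly into a torch.utils.data.DataLoader for batching and shuffling during training.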

On the first load, the data will be cached in a data.cache file in the same folder as the data files for faster subsequent loading. If the data have changed, manually remove the data.cache file.
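The caching behavior described above can be sketched roughly as follows (the helper name, pickle format, and parse callback are assumptions; the repo's actual cache logic may differ):

```python
import os
import pickle


def load_with_cache(data_dir, parse_fn, cache_name="data.cache"):
    """Load parsed data from a cache file sitting next to the data files.

    On the first call the (slow) parse_fn runs and its result is cached;
    later calls read the cache instead. Deleting the cache file forces
    a re-parse, which is why stale caches must be removed by hand.
    """
    cache_path = os.path.join(data_dir, cache_name)
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    data = parse_fn(data_dir)  # slow parse of the raw data files
    with open(cache_path, "wb") as f:
        pickle.dump(data, f)
    return data
```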

Disclaimer

For the purpose of this demo, the full dataset has been used for training. In a real context, time should be spent on building a validation set that fits the project requirements.

Launching Training
python training.py

Testing

Run the command: `python -m unittest tests.test_bert4rec`

Additional Information

The vocabulary size has been set to the maximum number of films in the movie_title_by_index.json file: 40857. However, after parsing the data records, only 28221 of these films actually appear in the data.
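The gap between the declared vocabulary and the items actually observed can be checked with a small helper like the one below (the function name is illustrative; movie_title_by_index.json is the repo's file, but how the observed ids are collected from the records is an assumption):

```python
import json


def vocab_coverage(vocab_path, observed_ids):
    """Compare the full vocabulary (e.g. movie_title_by_index.json)
    against the set of item ids actually seen in the parsed records.

    Returns (declared_vocab_size, distinct_observed_items).
    """
    with open(vocab_path) as f:
        title_by_index = json.load(f)
    return len(title_by_index), len(set(observed_ids))
```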

Sources Used: