
In-progress implementation of a Gato-style generalist multimodal model capable of image, text, RL, and robotics tasks



ManifoldRG/NEKO


NEKO 🐈

Check out our Discord here.

Status

This implementation is currently in progress.

Vision

The NEKO Project is an open-source effort to build a "generalist" model of greater scale and capability than that reported in DeepMind’s 2022 paper, A Generalist Agent. This is the first major step toward a longer-term goal: building multimodal, multi-objective models that work well across a variety of domains.

If you use this project, we'd love for you to refer back to Manifold in your code!

Setup

Manual Setup

conda env create -f env.yml 

Our code also works in Colab; a minimal example can be found here.

Minari

We rely on Minari to provide a standard format for datasets. However, we are currently testing with MuJoCo locomotion and Atari tasks, which are not included in Minari by default. These datasets are generated with https://github.com/daniellawson9999/data-tests/ and can be downloaded with:

cd ..
python ./gato/data/download_custom_datasets.py

(This downloads only the MuJoCo datasets; refer to the script for downloading others, such as Breakout.)

Docker

docker build -t gato-control -f ./docker/Dockerfile .
docker run -it --mount "type=bind,source=$(pwd),target=/app/gato-control" --entrypoint /bin/bash --gpus=all gato-control

Training

Below are some example training commands.

Training on a text dataset (e.g. WikiText from Hugging Face):

python train.py --embed_dim=768 --layers=6 --heads=24 --training_steps=1000 --log_eval_freq=10 --warmup_steps=20 --batch_size=16 --sequence_length=1024 --eval_episodes=10 --activation_fn=gelu --save_model --save_mode=checkpoint --text_prop=1.0 --eval_text_log_examples --text_datasets=wikitext-2-v1 --text_datasets_paths=wikitext --use_wandb --pretrained_lm=gpt2 --disable_cosine_decay

example run log: https://wandb.ai/bhavul/gato-control/runs/jgqxfzxn/overview?workspace=user-bhavul

Training on 3 MuJoCo locomotion tasks:

python train.py --embed_dim=768 --layers=6 --heads=24 --training_steps=100000 --log_eval_freq=10000 --warmup_steps=10000 --batch_size=32 -k=240 --eval_episodes=10 --activation_fn=gelu --save_model --save_mode=checkpoint --control_datasets d4rl_halfcheetah-expert-v2 d4rl_hopper-expert-v2 d4rl_walker2d-expert-v2 -w

example run log: https://wandb.ai/daniellawson9999/gato-control/runs/j9u26q9p/overview?workspace=user-daniellawson9999

Atari (in progress):

python train.py --embed_dim=128 --layers=3 --heads=1 --training_steps=10000 --log_eval_freq=1 --warmup_steps=100 --batch_size=4 -k=512 --eval_episodes=1 --device=cuda --control_datasets Breakout-top1-s1-v0

example run log: https://wandb.ai/daniellawson9999/gato-control/runs/qagorj06/workspace?workspace=user-daniellawson9999

In general, --control_datasets can contain any of the dataset names in download_custom_datasets.py, or any dataset in https://minari.farama.org/ with Box or Discrete observation and action spaces, although not all default Minari environments have been tested yet. Datasets can be mixed in a single run, e.g.: --control_datasets Breakout-top1-s1-v0 hammer-expert-v0
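The mixed-dataset example passes several space-separated names after a single flag. A minimal sketch of how such a flag can be parsed, assuming train.py declares it with argparse and nargs="+" (we have not confirmed the real declaration; this parser is a stand-in):

```python
import argparse

# Hypothetical stand-in for the train.py parser; the real one may differ.
# nargs="+" collects one or more space-separated values into a list.
parser = argparse.ArgumentParser()
parser.add_argument("--control_datasets", nargs="+", default=[])

args = parser.parse_args(
    ["--control_datasets", "Breakout-top1-s1-v0", "hammer-expert-v0"]
)
print(args.control_datasets)  # ['Breakout-top1-s1-v0', 'hammer-expert-v0']
```

With nargs="+", every token up to the next flag is treated as a dataset name, which is why no commas are needed between names.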

Training Image-Caption (in progress):

python train.py --use_wandb --embed_dim=768 --layers=6 --heads=24 --training_steps=1000 --log_eval_freq=10 --warmup_steps=10 --batch_size=4 -k=240 --eval_episodes=10 --sequence_length=1024 --activation_fn=gelu --save_model --caption_prop=1.0 --caption_dataset="/<your data path>/Caption_Data" --caption_train_data=train --caption_test_data=test

Training VQA (in progress):

python train.py --embed_dim=768 --layers=6 --heads=24 --training_steps=1000 --log_eval_freq=10 --warmup_steps=10 --batch_size=4 -k=240 --eval_episodes=10 --sequence_length=1024 --activation_fn=gelu --save_model --vqa_prop=1.0 --vqa_dataset='/<your data path>/VQA_Data/' --vqa_train_data=train2014 --vqa_test_data=val2014 --train_img_name_prefix=COCO_train2014_ --train_img_file_name_len=27 --test_img_name_prefix=COCO_val2014_ --test_img_file_name_len=25

The --caption_prop and --vqa_prop flags set the proportions of samples drawn from each of the two tasks (caption and VQA). When multiple tasks are trained simultaneously, which should be the case for normal training, the proportions for all tasks (control tasks such as Atari, and non-control tasks such as text, image-caption, and VQA) should sum to 1.0. The examples above single out each task for demo and testing purposes.
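To make the sum-to-1.0 rule concrete, here is a hypothetical sketch (not NEKO's actual sampler) that validates a set of task proportions and splits one batch across tasks using the largest-remainder method. The task keys and the helpers check_proportions and split_batch are illustrative only:

```python
import math


def check_proportions(props, tol=1e-6):
    """Raise if the task proportions do not sum to 1.0 (within tol)."""
    total = sum(props.values())
    if not math.isclose(total, 1.0, abs_tol=tol):
        raise ValueError(f"task proportions sum to {total}, expected 1.0")


def split_batch(props, batch_size):
    """Split batch_size samples across tasks (largest-remainder method).

    Each task first gets floor(p * batch_size) samples; the leftover samples
    go to the tasks with the largest fractional parts, so counts always add
    up to batch_size.
    """
    check_proportions(props)
    raw = {task: p * batch_size for task, p in props.items()}
    counts = {task: math.floor(x) for task, x in raw.items()}
    leftover = batch_size - sum(counts.values())
    by_remainder = sorted(raw, key=lambda t: raw[t] - counts[t], reverse=True)
    for task in by_remainder[:leftover]:
        counts[task] += 1
    return counts


counts = split_batch({"control": 0.5, "text": 0.3, "caption": 0.1, "vqa": 0.1}, 16)
print(counts)  # {'control': 8, 'text': 5, 'caption': 2, 'vqa': 1}
```

Rounding each share independently could leave the batch one sample short or long; distributing the remainder explicitly keeps the batch size fixed.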

The Image-Caption and VQA tasks can be tested on Colab; we have a few Colab notebooks for that purpose in the NEKO/misc folder.

Atari Datasets

All Atari datasets now follow the convention {Name}-top1-s1-v0, e.g. Breakout-top1-s1-v0. Older runs may use Breakout-expert_s0-v0, which is deprecated. These datasets are the top-1% DQN Replay data converted to Minari; refer here for more details.
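A small sketch of the naming convention, useful for catching deprecated names in old configs before launching a run. The helpers atari_dataset_name and parse_atari_dataset are hypothetical and not part of the repository; note that the sketch does not map old names to new ones, since the two schemes are not necessarily the same data:

```python
import re

# Current convention, e.g. "Breakout-top1-s1-v0".
CURRENT = re.compile(r"^(?P<game>[A-Za-z0-9]+)-top1-s1-v0$")
# Deprecated convention used by older runs, e.g. "Breakout-expert_s0-v0".
DEPRECATED = re.compile(r"^(?P<game>[A-Za-z0-9]+)-expert_s0-v0$")


def atari_dataset_name(game):
    """Build a dataset name in the {Name}-top1-s1-v0 convention."""
    return f"{game}-top1-s1-v0"


def parse_atari_dataset(name):
    """Return the game name, rejecting deprecated or unrecognized names."""
    if DEPRECATED.match(name):
        raise ValueError(f"{name!r} uses the deprecated naming scheme")
    match = CURRENT.match(name)
    if match is None:
        raise ValueError(f"{name!r} does not follow {{Name}}-top1-s1-v0")
    return match.group("game")


print(atari_dataset_name("Breakout"))              # Breakout-top1-s1-v0
print(parse_atari_dataset("Breakout-top1-s1-v0"))  # Breakout
```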

You will be able to train on any env in https://github.com/ManifoldRG/gato-control/blob/master/gato/envs/atari.py. To train on all 40 training games, pass --datasets TOP1_ATARI_TRAIN, or pass --datasets TOP1_ATARI_TEST for the 5 testing environments.

Currently, only Breakout is provided here for testing but others will be available shortly.

Image-Caption Datasets

So far we have identified two datasets; more may be added in the future. For both, we used the img2dataset tool to download the data in webdataset format:

  • Data are downloaded into .tar files; each .tar file contains multiple bundles.
  • Each bundle contains:
      • one image in JPG format, resized to the designated size (256×256 by default);
      • one .txt file containing the caption for the image;
      • one .json file containing the metadata for the bundle (the URL of the image, the caption, the image size, etc.).

When the Image-Caption task is instantiated, it processes the downloaded data into the format the model accepts for training. So far, the task can only process data in webdataset format. As more data sources are identified, other processing methods may be added as needed.
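To illustrate the bundle layout described above, here is a stdlib-only sketch that groups the members of one webdataset .tar file by key (the file name up to its first dot). It is not the Image-Caption task's actual loader; read_bundles, make_demo_tar, and the dummy payloads are hypothetical:

```python
import io
import json
import tarfile
from collections import defaultdict


def read_bundles(tar_bytes):
    """Group the files in a webdataset-style .tar into per-sample bundles.

    Files sharing a key (directory + basename up to the first dot), e.g.
    000000000.jpg, 000000000.txt and 000000000.json, form one bundle.
    """
    raw = defaultdict(dict)  # key -> {extension: file bytes}
    with tarfile.open(fileobj=io.BytesIO(tar_bytes)) as tar:
        for member in tar.getmembers():
            if not member.isfile():
                continue
            dirname, _, basename = member.name.rpartition("/")
            key, _, ext = basename.partition(".")
            full_key = f"{dirname}/{key}" if dirname else key
            raw[full_key][ext] = tar.extractfile(member).read()

    bundles = {}
    for key, files in raw.items():
        bundles[key] = {
            "image_bytes": files.get("jpg"),
            "caption": files["txt"].decode("utf-8") if "txt" in files else None,
            "metadata": json.loads(files["json"]) if "json" in files else None,
        }
    return bundles


# Usage: build a one-bundle .tar in memory with dummy payloads.
def make_demo_tar():
    payloads = {
        "000000000.jpg": b"not a real jpeg",
        "000000000.txt": b"a cat sitting on a mat",
        "000000000.json": json.dumps({"caption": "a cat sitting on a mat"}).encode(),
    }
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for name, data in payloads.items():
            info = tarfile.TarInfo(name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))
    return buf.getvalue()


bundles = read_bundles(make_demo_tar())
print(bundles["000000000"]["caption"])  # a cat sitting on a mat
```

A real loader would also decode and normalize the image bytes before tokenization; this sketch stops at grouping the raw files.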

The two datasets for the Image-Caption task:

VQA Datasets

So far, we have identified one dataset, OK-VQA (https://okvqa.allenai.org/download.html); follow the instructions there to download it. Each download includes a questions JSON file and an annotations JSON file, which list the image IDs together with the questions and answers associated with each image ID, plus the image files, whose file names contain the image IDs.

When the VQA task is instantiated, it processes the downloaded data into the format the model accepts for training. So far, the task can only process data in this specific format. As more data sources are identified, other processing methods may be added as needed.
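As a sketch of what this processing involves, the snippet below joins questions to answers by question ID and derives each image file name. It assumes the VQA v2-style JSON layout (a "questions" list with image_id, question, and question_id, and an "annotations" list whose "answers" entries carry an "answer" field). It also reads --train_img_file_name_len / --test_img_file_name_len as the file-name length without the extension, which matches the COCO values in the VQA command above (15 + 12 = 27 and 13 + 12 = 25). Both are assumptions, and image_file_name and build_samples are hypothetical helpers:

```python
def image_file_name(image_id, prefix, name_len, ext=".jpg"):
    """Zero-pad image_id so that prefix + padded id has name_len characters."""
    digits = name_len - len(prefix)  # e.g. 27 - len("COCO_train2014_") = 12
    return f"{prefix}{image_id:0{digits}d}{ext}"


def build_samples(questions, annotations, prefix, name_len):
    """Join questions to their answers by question_id and attach image file names."""
    answers_by_qid = {
        ann["question_id"]: [a["answer"] for a in ann["answers"]]
        for ann in annotations["annotations"]
    }
    return [
        {
            "image_file": image_file_name(q["image_id"], prefix, name_len),
            "question": q["question"],
            "answers": answers_by_qid.get(q["question_id"], []),
        }
        for q in questions["questions"]
    ]


# Usage with tiny, made-up records in the assumed layout.
questions = {"questions": [{"image_id": 9, "question": "What animal is this?", "question_id": 90}]}
annotations = {"annotations": [{"image_id": 9, "question_id": 90,
                                "answers": [{"answer": "cat"}, {"answer": "kitten"}]}]}
samples = build_samples(questions, annotations, "COCO_val2014_", 25)
print(samples[0]["image_file"])  # COCO_val2014_000000000009.jpg
```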

Evaluation

python eval.py --model_path={model_path} --eval_episodes={n_episodes}

Examples

import torch
from gato.policy.gato_policy import GatoPolicy

model = GatoPolicy(
        device='cpu',
        embed_dim=128,
        layers=2,
        heads=4,
        dropout=0.1,
)

# This computes logits and (cross-entropy) loss over a batch of size three, where each dictionary is an episode in the batch

logits, loss = model([
    {
        'images': torch.randn(20, 3, 80, 64),
        'discrete_actions': torch.randint(0, 55, (20, 1)),
    },
    {
        'continuous_obs': torch.randn(15, 8),
        'continuous_actions': torch.randn(15, 4),
    },
    {
        'images': torch.randn(100, 3, 224, 224),
        'continuous_actions': torch.randn(100, 11),
    }
], compute_loss=True)
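In the example above, every tensor within one episode shares its leading dimension (20, 15, and 100 timesteps respectively). The sketch below checks that property on shape tuples before building a batch; treating it as a requirement of GatoPolicy is our inference from the example, and episode_length is a hypothetical helper:

```python
def episode_length(shapes):
    """Return the timestep count shared by every modality in one episode.

    `shapes` maps modality keys to tensor shapes; the leading dimension is
    assumed to be the number of timesteps.
    """
    lengths = {key: shape[0] for key, shape in shapes.items()}
    if len(set(lengths.values())) != 1:
        raise ValueError(f"inconsistent timestep counts: {lengths}")
    return next(iter(lengths.values()))


# Shapes taken from the first and second episodes in the example above.
print(episode_length({"images": (20, 3, 80, 64), "discrete_actions": (20, 1)}))  # 20
print(episode_length({"continuous_obs": (15, 8), "continuous_actions": (15, 4)}))  # 15
```

With real tensors, pass {key: tuple(value.shape) for key, value in episode.items()}.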

Pretrained models

We provide some pretrained models. They are not geared toward high performance or reliable external use; they are meant to aid our open-source development. Currently they are available for the 3 MuJoCo tasks and Breakout, and other models may be added later. Each model directory contains checkpoints, the training arguments, and an info.txt with a link to the WandB run.

Future

Our implementation does not directly mirror Gato. Features that were left out or are planned for the future are listed in todo.md. We are working on adding modular tasks; check out the Issues tab.

Credits

This implementation is influenced by and uses components from: