This is a series of experiments I did with doodle classifiers (convolutional neural networks) using TensorFlow.js and TensorFlow. The data comes from the Quickdraw dataset.
Here is a list of the projects:
- Train a doodle classifier with tf.js
- Train a doodle classifier with 345 classes
- KNN doodle classifier
Credits: Big thanks to @zaidalyafeai's Sketcher Google Colab notebook for training.
I trained a doodle classifier with 3 classes (bowtie, lollipop, rainbow) in the browser using the tf.js Layers API and tfjs-vis. The code is based on the tf.js example "Training MNIST".
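The tf.js MNIST example trains with one-hot labels and a categorical cross-entropy loss. Assuming the doodle version does the same, the label encoding for the three classes works roughly like the Python sketch below. The class order and the helper names (`one_hot`, `categorical_crossentropy`) are hypothetical and not taken from the demo's code.

```python
import math
from typing import List

# Hypothetical class order; the demo's actual label indices may differ.
CLASSES = ["bowtie", "lollipop", "rainbow"]


def one_hot(label: str, classes: List[str] = CLASSES) -> List[float]:
    """Return a vector of zeros with 1.0 at the label's index."""
    vec = [0.0] * len(classes)
    vec[classes.index(label)] = 1.0
    return vec


def categorical_crossentropy(target: List[float], probs: List[float], eps: float = 1e-7) -> float:
    """Cross-entropy between a one-hot target and predicted class probabilities.

    Probabilities are clipped to eps so log(0) never occurs.
    """
    return -sum(t * math.log(max(p, eps)) for t, p in zip(target, probs))


# A 'lollipop' drawing that the model scores at 0.7 gives a loss of -log(0.7).
loss = categorical_crossentropy(one_hot("lollipop"), [0.2, 0.7, 0.1])
```

With a one-hot target, only the probability assigned to the true class contributes to the loss, so training pushes that probability toward 1.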
Try a live demo here.
Once you open the webpage, wait until the page loads the data, trains the model, and evaluates it. It will then download two files: myDoodleNet.json and myDoodleNet.weights.bin. To test this model yourself, load these two files back, click the 'load model' button, draw something on the canvas, and hit the 'Guess' button to let the model guess your drawing.
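When you hit 'Guess', the model produces one score per class, and the guess is the class with the highest probability. A minimal Python sketch of that last step, assuming raw scores are turned into probabilities with a softmax (the class list and function names are illustrative, not the demo's code):

```python
import math

# Hypothetical class list; the real output order comes from the training labels.
CLASSES = ["bowtie", "lollipop", "rainbow"]


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max score for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_guesses(probs, classes=CLASSES, k=3):
    """Return the k most likely (class, probability) pairs, highest first."""
    ranked = sorted(zip(classes, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


probs = softmax([2.0, 1.0, 0.1])
guesses = top_k_guesses(probs, k=2)
```

Showing the top few guesses instead of only the best one is useful for doodles, since many quick drawings are ambiguous between similar classes.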
The second classifier is trained on all 345 categories of the Quickdraw dataset, with 50k images per class. It was trained with TensorFlow and then ported to tf.js to run in the browser. Here is the training notebook.
This notebook is heavily based on @zaidalyafeai's Sketcher notebook, which covers 100 classes. I expanded the data to 345 classes and added a few layers to improve the accuracy on the larger set.
I used a remote GPU machine with large RAM on spell.run to load all the data and train the model.
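A back-of-the-envelope calculation shows why a machine with large RAM was needed. Assuming the Quickdraw numpy bitmaps (28×28 grayscale, 784 pixels per image), 345 × 50,000 images take about 13.5 GB as uint8 and about 54 GB if converted to float32 for training. This sketch ignores labels and any extra copies made during preprocessing.

```python
# Assumption: each Quickdraw numpy bitmap is a 28x28 grayscale image (784 pixels).
NUM_CLASSES = 345
IMAGES_PER_CLASS = 50_000
PIXELS_PER_IMAGE = 28 * 28


def dataset_bytes(bytes_per_pixel: int) -> int:
    """Total bytes needed to hold every image in memory at the given pixel width."""
    return NUM_CLASSES * IMAGES_PER_CLASS * PIXELS_PER_IMAGE * bytes_per_pixel


uint8_gb = dataset_bytes(1) / 1e9    # raw uint8 pixels
float32_gb = dataset_bytes(4) / 1e9  # after converting pixels to float32
print(f"uint8: {uint8_gb:.1f} GB, float32: {float32_gb:.1f} GB")
# prints: uint8: 13.5 GB, float32: 54.1 GB
```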
Try a live demo here.
The third project builds on the 345-class doodle classifier: I added a KNN classifier on top of it so people can define their own doodle classes.
Try a live demo here.
You can draw 10+ circles and add them to class A, draw 10+ lines and add them to class B, and then let the model guess your new drawing. You can define any other classes; they don't need to be circles or lines.
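A common way to combine the two, and an assumption here, is to use the 345-class network to turn each drawing into a feature vector and let the KNN classifier label a new drawing by a vote among the stored examples closest to it. A minimal Python sketch of that voting step, assuming the feature vectors are already computed (the class and method names are illustrative, not the demo's API):

```python
import math
from collections import Counter


def euclidean(a, b):
    """Straight-line distance between two equal-length feature vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


class KNNClassifier:
    """Minimal k-nearest-neighbours classifier over feature vectors."""

    def __init__(self, k=3):
        self.k = k
        self.examples = []  # list of (feature_vector, label) pairs

    def add_example(self, features, label):
        self.examples.append((list(features), label))

    def predict(self, features):
        if not self.examples:
            raise ValueError("add at least one example before predicting")
        # Keep the k stored examples closest to the new drawing's features.
        nearest = sorted(self.examples, key=lambda ex: euclidean(ex[0], features))[: self.k]
        # Majority vote; on a tie, the label seen first among the nearest wins.
        votes = Counter(label for _, label in nearest)
        return votes.most_common(1)[0][0]


# Hypothetical 2-D "features": circles cluster near (0, 0), lines near (10, 10).
knn = KNNClassifier(k=3)
for point in [(0, 0), (0, 1), (1, 0)]:
    knn.add_example(point, "A")
for point in [(10, 10), (10, 9), (9, 10)]:
    knn.add_example(point, "B")
```

Euclidean distance is used here for simplicity; a library implementation may use a different metric such as cosine similarity. Because KNN only stores examples, adding a new class needs no retraining of the network, which is why 10+ drawings per class are enough to get started.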
To run each example locally, open your terminal and type in the following commands:
$ git clone https://github.com/yining1023/doodleNet.git
$ cd doodleNet
$ python -m SimpleHTTPServer  # Python 2; with Python 3, use: python3 -m http.server
Go to localhost:8000/demo in your browser, and you will see a directory listing like this:
- DoodleClassifier_345/
- DoodleClassifier_KNN/
- TrainDoodleClassifier/
Click into each example to see the demo.