diff --git a/.all-contributorsrc b/.all-contributorsrc
new file mode 100644
index 000000000..a96940963
--- /dev/null
+++ b/.all-contributorsrc
@@ -0,0 +1,558 @@
+{
+ "files": [
+ "README.md"
+ ],
+ "imageSize": 100,
+ "commit": false,
+ "contributors": [
+ {
+ "login": "shiffman",
+ "name": "Daniel Shiffman",
+ "avatar_url": "https://avatars0.githubusercontent.com/u/191758?v=4",
+ "profile": "http://www.shiffman.net",
+ "contributions": [
+ "code",
+ "example",
+ "projectManagement",
+ "review",
+ "test",
+ "video"
+ ]
+ },
+ {
+ "login": "cvalenzuela",
+ "name": "CristΓ³bal Valenzuela",
+ "avatar_url": "https://avatars0.githubusercontent.com/u/10605821?v=4",
+ "profile": "https://cvalenzuelab.com/",
+ "contributions": [
+ "code",
+ "example",
+ "review",
+ "tool",
+ "test"
+ ]
+ },
+ {
+ "login": "yining1023",
+ "name": "Yining Shi",
+ "avatar_url": "https://avatars3.githubusercontent.com/u/8662372?v=4",
+ "profile": "https://1023.io",
+ "contributions": [
+ "code",
+ "example",
+ "review",
+ "tool",
+ "test",
+ "bug"
+ ]
+ },
+ {
+ "login": "handav",
+ "name": "Hannah Davis",
+ "avatar_url": "https://avatars0.githubusercontent.com/u/1385308?v=4",
+ "profile": "http://www.hannahishere.com",
+ "contributions": [
+ "code",
+ "example"
+ ]
+ },
+ {
+ "login": "joeyklee",
+ "name": "Joey Lee",
+ "avatar_url": "https://avatars1.githubusercontent.com/u/3622055?v=4",
+ "profile": "https://jk-lee.com/",
+ "contributions": [
+ "code",
+ "example",
+ "review",
+ "content",
+ "test"
+ ]
+ },
+ {
+ "login": "AshleyJaneLewis",
+ "name": "AshleyJaneLewis",
+ "avatar_url": "https://avatars3.githubusercontent.com/u/43127855?v=4",
+ "profile": "https://github.com/AshleyJaneLewis",
+ "contributions": [
+ "blog",
+ "design",
+ "eventOrganizing",
+ "content"
+ ]
+ },
+ {
+ "login": "ellennickles",
+ "name": "Ellen Nickles",
+ "avatar_url": "https://avatars2.githubusercontent.com/u/31713501?v=4",
+ "profile": "https://ellennickles.site/",
+ "contributions": [
+ "blog",
+ "content",
+ "ideas",
+ "tutorial"
+ ]
+ },
+ {
+ "login": "itayniv",
+ "name": "Itay Niv",
+ "avatar_url": "https://avatars1.githubusercontent.com/u/5209486?v=4",
+ "profile": "http://www.itayniv.com",
+ "contributions": [
+ "code",
+ "example"
+ ]
+ },
+ {
+ "login": "nikitahuggins",
+ "name": "Nikita Huggins",
+ "avatar_url": "https://avatars2.githubusercontent.com/u/18563958?v=4",
+ "profile": "http://nikitahuggins.com",
+ "contributions": [
+ "blog",
+ "content",
+ "ideas"
+ ]
+ },
+ {
+ "login": "AbolTaabol",
+ "name": "Arnab Chakravarty",
+ "avatar_url": "https://avatars2.githubusercontent.com/u/19427655?v=4",
+ "profile": "http://www.arnabchakravarty.com",
+ "contributions": [
+ "content",
+ "userTesting"
+ ]
+ },
+ {
+ "login": "AidanNelson",
+ "name": "Aidan Nelson",
+ "avatar_url": "https://avatars2.githubusercontent.com/u/6486359?v=4",
+ "profile": "http://www.aidanjnelson.com/",
+ "contributions": [
+ "code",
+ "example"
+ ]
+ },
+ {
+ "login": "WenheLI",
+ "name": "WenheLI",
+ "avatar_url": "https://avatars1.githubusercontent.com/u/23213772?v=4",
+ "profile": "http://portfolio.steins.live",
+ "contributions": [
+ "code",
+ "example",
+ "maintenance",
+ "ideas"
+ ]
+ },
+ {
+ "login": "dariusk",
+ "name": "Darius Kazemi",
+ "avatar_url": "https://avatars3.githubusercontent.com/u/266454?v=4",
+ "profile": "https://tinysubversions.com",
+ "contributions": [
+ "ideas",
+ "question"
+ ]
+ },
+ {
+ "login": "Derek-Wds",
+ "name": "Dingsu Wang",
+ "avatar_url": "https://avatars1.githubusercontent.com/u/26081991?v=4",
+ "profile": "https://wangdingsu.com",
+ "contributions": [
+ "code",
+ "example"
+ ]
+ },
+ {
+ "login": "garym140",
+ "name": "garym140",
+ "avatar_url": "https://avatars1.githubusercontent.com/u/30574513?v=4",
+ "profile": "https://github.com/garym140",
+ "contributions": [
+ "content",
+ "blog",
+ "ideas",
+ "userTesting"
+ ]
+ },
+ {
+ "login": "genekogan",
+ "name": "Gene Kogan",
+ "avatar_url": "https://avatars3.githubusercontent.com/u/1335251?v=4",
+ "profile": "http://genekogan.com",
+ "contributions": [
+ "code",
+ "example",
+ "ideas"
+ ]
+ },
+ {
+ "login": "hhayley",
+ "name": "Hayley Hwang",
+ "avatar_url": "https://avatars1.githubusercontent.com/u/22086451?v=4",
+ "profile": "http://hhayeon.com",
+ "contributions": [
+ "code",
+ "example",
+ "ideas"
+ ]
+ },
+ {
+ "login": "lisajamhoury",
+ "name": "Lisa Jamhoury",
+ "avatar_url": "https://avatars0.githubusercontent.com/u/7552772?v=4",
+ "profile": "http://lisajamhoury.com",
+ "contributions": [
+ "example",
+ "ideas"
+ ]
+ },
+ {
+ "login": "matamalaortiz",
+ "name": "Alejandro Matamala Ortiz",
+ "avatar_url": "https://avatars2.githubusercontent.com/u/5123955?v=4",
+ "profile": "https://www.matamala.info",
+ "contributions": [
+ "design",
+ "content",
+ "blog"
+ ]
+ },
+ {
+ "login": "mayaman",
+ "name": "Maya Man",
+ "avatar_url": "https://avatars0.githubusercontent.com/u/8224678?v=4",
+ "profile": "http://mayaontheinter.net",
+ "contributions": [
+ "code",
+ "example"
+ ]
+ },
+ {
+ "login": "MimiOnuoha",
+ "name": "Mimi Onuoha",
+ "avatar_url": "https://avatars0.githubusercontent.com/u/1565846?v=4",
+ "profile": "http://mimionuoha.com",
+ "contributions": [
+ "ideas",
+ "content",
+ "review"
+ ]
+ },
+ {
+ "login": "NHibiki",
+ "name": "Yuuno, Hibiki",
+ "avatar_url": "https://avatars1.githubusercontent.com/u/18514672?v=4",
+ "profile": "https://i.yuuno.cc/",
+ "contributions": [
+ "code",
+ "example",
+ "maintenance"
+ ]
+ },
+ {
+ "login": "oveddan",
+ "name": "Dan Oved",
+ "avatar_url": "https://avatars3.githubusercontent.com/u/891755?v=4",
+ "profile": "http://www.danioved.com/",
+ "contributions": [
+ "code",
+ "example",
+ "question",
+ "ideas"
+ ]
+ },
+ {
+ "login": "stephkoltun",
+ "name": "Stephanie Koltun",
+ "avatar_url": "https://avatars0.githubusercontent.com/u/7053425?v=4",
+ "profile": "http://anothersideproject.co",
+ "contributions": [
+ "code",
+ "example",
+ "content",
+ "blog",
+ "design"
+ ]
+ },
+ {
+ "login": "viztopia",
+ "name": "YG Zhang",
+ "avatar_url": "https://avatars3.githubusercontent.com/u/37890050?v=4",
+ "profile": "https://github.com/viztopia",
+ "contributions": [
+ "code",
+ "example",
+ "ideas"
+ ]
+ },
+ {
+ "login": "wenqili",
+ "name": "Wenqi Li",
+ "avatar_url": "https://avatars1.githubusercontent.com/u/22087042?v=4",
+ "profile": "https://www.wenqi.li",
+ "contributions": [
+ "code",
+ "example",
+ "infra"
+ ]
+ },
+ {
+ "login": "brondle",
+ "name": "Brent Bailey",
+ "avatar_url": "https://avatars3.githubusercontent.com/u/12499678?v=4",
+ "profile": "http://brentlbailey.com",
+ "contributions": [
+ "test",
+ "code",
+ "example"
+ ]
+ },
+ {
+ "login": "Jonarod",
+ "name": "Jonarod",
+ "avatar_url": "https://avatars3.githubusercontent.com/u/6122703?v=4",
+ "profile": "https://github.com/Jonarod",
+ "contributions": [
+ "code"
+ ]
+ },
+ {
+ "login": "JazzTap",
+ "name": "Jasmine Otto",
+ "avatar_url": "https://avatars2.githubusercontent.com/u/15673619?v=4",
+ "profile": "https://jazztap.github.io",
+ "contributions": [
+ "code",
+ "test",
+ "example"
+ ]
+ },
+ {
+ "login": "zaidalyafeai",
+ "name": "Zaid Alyafeai",
+ "avatar_url": "https://avatars2.githubusercontent.com/u/15667714?v=4",
+ "profile": "https://twitter.com/zaidalyafeai",
+ "contributions": [
+ "code",
+ "example",
+ "ideas",
+ "question"
+ ]
+ },
+ {
+ "login": "AlcaDesign",
+ "name": "Jacob Foster",
+ "avatar_url": "https://avatars2.githubusercontent.com/u/7132646?v=4",
+ "profile": "https://alca.tv",
+ "contributions": [
+ "code",
+ "example",
+ "test"
+ ]
+ },
+ {
+ "login": "memo",
+ "name": "Memo Akten",
+ "avatar_url": "https://avatars0.githubusercontent.com/u/144230?v=4",
+ "profile": "http://www.memo.tv",
+ "contributions": [
+ "code",
+ "example"
+ ]
+ },
+ {
+ "login": "TheHidden1",
+ "name": "Mohamed Amine",
+ "avatar_url": "https://avatars0.githubusercontent.com/u/31354864?v=4",
+ "profile": "https://thehidden1.github.io/",
+ "contributions": [
+ "code",
+ "example",
+ "ideas",
+ "test"
+ ]
+ },
+ {
+ "login": "meiamsome",
+ "name": "Oliver Wright",
+ "avatar_url": "https://avatars3.githubusercontent.com/u/1187491?v=4",
+ "profile": "http://meiamso.me",
+ "contributions": [
+ "code",
+ "test"
+ ]
+ },
+ {
+ "login": "marshalhayes",
+ "name": "Marshal Hayes",
+ "avatar_url": "https://avatars0.githubusercontent.com/u/17213165?v=4",
+ "profile": "https://marshalhayes.dev",
+ "contributions": [
+ "doc"
+ ]
+ },
+ {
+ "login": "reiinakano",
+ "name": "Reiichiro Nakano",
+ "avatar_url": "https://avatars0.githubusercontent.com/u/18363734?v=4",
+ "profile": "https://reiinakano.github.io",
+ "contributions": [
+ "code",
+ "test",
+ "example"
+ ]
+ },
+ {
+ "login": "nsthorat",
+ "name": "Nikhil Thorat",
+ "avatar_url": "https://avatars0.githubusercontent.com/u/1100749?v=4",
+ "profile": "https://deeplearnjs.org/",
+ "contributions": [
+ "code",
+ "example",
+ "ideas",
+ "infra"
+ ]
+ },
+ {
+ "login": "irealva",
+ "name": "Irene Alvarado",
+ "avatar_url": "https://avatars2.githubusercontent.com/u/8978670?v=4",
+ "profile": "http://www.irenealvarado.com",
+ "contributions": [
+ "code",
+ "example",
+ "maintenance",
+ "ideas"
+ ]
+ },
+ {
+ "login": "vndrewlee",
+ "name": "Andrew Lee",
+ "avatar_url": "https://avatars0.githubusercontent.com/u/6391516?v=4",
+ "profile": "http://www.vndrewlee.com/",
+ "contributions": [
+ "code",
+ "example",
+ "ideas"
+ ]
+ },
+ {
+ "login": "fjcamillo",
+ "name": "Jerhone",
+ "avatar_url": "https://avatars2.githubusercontent.com/u/12166244?v=4",
+ "profile": "https://medium.com/@fjcamillo.dev",
+ "contributions": [
+ "doc"
+ ]
+ },
+ {
+ "login": "achimkoh",
+ "name": "achimkoh",
+ "avatar_url": "https://avatars1.githubusercontent.com/u/3594463?v=4",
+ "profile": "https://scalarvectortensor.net/",
+ "contributions": [
+ "code",
+ "example",
+ "test"
+ ]
+ },
+ {
+ "login": "hx2A",
+ "name": "Jim",
+ "avatar_url": "https://avatars3.githubusercontent.com/u/4044283?v=4",
+ "profile": "http://ixora.io",
+ "contributions": [
+ "example",
+ "doc",
+ "content"
+ ]
+ },
+ {
+ "login": "champierre",
+ "name": "Junya Ishihara",
+ "avatar_url": "https://avatars1.githubusercontent.com/u/10215?v=4",
+ "profile": "https://github.com/champierre/resume",
+ "contributions": [
+ "maintenance",
+ "code"
+ ]
+ },
+ {
+ "login": "micuat",
+ "name": "Naoto HIΓDA",
+ "avatar_url": "https://avatars3.githubusercontent.com/u/1835081?v=4",
+ "profile": "http://naotohieda.com",
+ "contributions": [
+ "maintenance"
+ ]
+ },
+ {
+ "login": "montoyamoraga",
+ "name": "aarΓ³n montoya-moraga",
+ "avatar_url": "https://avatars3.githubusercontent.com/u/3926350?v=4",
+ "profile": "http://montoyamoraga.io",
+ "contributions": [
+ "maintenance",
+ "example"
+ ]
+ },
+ {
+ "login": "b2renger",
+ "name": "b2renger",
+ "avatar_url": "https://avatars2.githubusercontent.com/u/1818874?v=4",
+ "profile": "http://b2renger.github.io/",
+ "contributions": [
+ "code",
+ "infra"
+ ]
+ },
+ {
+ "login": "adityaas26",
+ "name": "Aditya Sharma",
+ "avatar_url": "https://avatars1.githubusercontent.com/u/24931529?v=4",
+ "profile": "http://adityasharma.me",
+ "contributions": [
+ "maintenance"
+ ]
+ },
+ {
+ "login": "okuna291",
+ "name": "okuna291",
+ "avatar_url": "https://avatars1.githubusercontent.com/u/5407359?v=4",
+ "profile": "https://github.com/okuna291",
+ "contributions": [
+ "ideas"
+ ]
+ },
+ {
+ "login": "xujenna",
+ "name": "Jenna",
+ "avatar_url": "https://avatars2.githubusercontent.com/u/13280722?v=4",
+ "profile": "http://www.xujenna.com",
+ "contributions": [
+ "ideas"
+ ]
+ },
+ {
+ "login": "nicoleflloyd",
+ "name": "nicoleflloyd",
+ "avatar_url": "https://avatars3.githubusercontent.com/u/35693567?v=4",
+ "profile": "https://github.com/nicoleflloyd",
+ "contributions": [
+ "content",
+ "design",
+ "userTesting"
+ ]
+ }
+ ],
+ "contributorsPerLine": 7,
+ "projectName": "ml5-library",
+ "projectOwner": "ml5js",
+ "repoType": "github",
+ "repoHost": "https://github.com"
+}
diff --git a/.gitignore b/.gitignore
index b1b6f2356..e8e770bce 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,6 +11,7 @@ training/lstm/data/t
.cache
/public
/static
+/dist
website/translated_docs
website/build/
diff --git a/README.md b/README.md
index 11ec80150..f6339197d 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,6 @@
# ![ml5](https://user-images.githubusercontent.com/10605821/41332516-2ee26714-6eac-11e8-83e4-a40b8761e764.png)
+[![All Contributors](https://img.shields.io/badge/all_contributors-50-orange.svg?style=flat-square)](#contributors)
[![BrowserStack Status](https://www.browserstack.com/automate/badge.svg?badge_key=QVNDdlkvMzNYSmhRRWlweXlIOTBENkd0MDBCOUJlbmFVZFRneFIzNlh4az0tLXA4S0loSGNlVUc2V2I3cVdLdXBKdGc9PQ==--8a5e5bfd3eafbba0702c02ec57ffec9d627a78ef)](https://www.browserstack.com/automate/public-build/QVNDdlkvMzNYSmhRRWlweXlIOTBENkd0MDBCOUJlbmFVZFRneFIzNlh4az0tLXA4S0loSGNlVUc2V2I3cVdLdXBKdGc9PQ==--8a5e5bfd3eafbba0702c02ec57ffec9d627a78ef)[![Version](https://img.shields.io/npm/v/ml5.svg?style=flat-square)](https://www.npmjs.com/package/ml5)
[![Twitter Follow](https://img.shields.io/twitter/follow/espadrine.svg?style=social&label=Follow)](https://twitter.com/ml5js)
@@ -19,15 +21,20 @@ ml5.js is heavily inspired by [Processing](https://processing.org/) and [p5.js](
There are several ways you can use the ml5.js library:
-* You can use the latest version (0.3.0) by adding it to the head section of your HTML document:
+* You can use the latest version (0.3.1) by adding it to the head section of your HTML document:
-**v0.3.0**
+**v0.3.1**
```javascript
-<script src="https://unpkg.com/ml5@0.3.0/dist/ml5.min.js" type="text/javascript"></script>
+<script src="https://unpkg.com/ml5@0.3.1/dist/ml5.min.js" type="text/javascript"></script>
```
* If you need to use an earlier version for any reason, you can change the version number.
+**v0.3.0**
+```javascript
+<script src="https://unpkg.com/ml5@0.3.0/dist/ml5.min.js" type="text/javascript"></script>
+```
+
**v0.2.3**
```javascript
@@ -79,4 +86,14 @@ ml5.js is supported by the time and dedication of open source developers from al
Many thanks to [BrowserStack](https://www.browserstack.com/) for providing testing support.
+## Contributors
+
+Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)):
+
+<!-- ALL-CONTRIBUTORS-LIST:START - Do not remove or modify this section -->
+<!-- prettier-ignore -->
+
+
+<!-- ALL-CONTRIBUTORS-LIST:END -->
+This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome!
\ No newline at end of file
diff --git a/assets/bird.jpg b/assets/bird.jpg
new file mode 100644
index 000000000..42c72fd8f
Binary files /dev/null and b/assets/bird.jpg differ
diff --git a/karma.conf.js b/karma.conf.js
index f29f87471..ad4123f58 100644
--- a/karma.conf.js
+++ b/karma.conf.js
@@ -38,6 +38,9 @@ module.exports = (config) => {
optimization: {
minimize: false,
},
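+ // Stub out Node's fs so the browser test bundle doesn't break on dependencies that reference it (presumably @tensorflow-models/speech-commands)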
+ node: {
+ fs: "empty"
+ }
},
webpackMiddleware: {
noInfo: true,
diff --git a/package-lock.json b/package-lock.json
index caa49d0ec..13c1cffd1 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -1,6 +1,6 @@
{
"name": "ml5",
- "version": "0.2.3",
+ "version": "0.3.0",
"lockfileVersion": 1,
"requires": true,
"dependencies": {
@@ -319,6 +319,11 @@
"integrity": "sha512-ONhaKPIufzzrlNbqtWFFd+jlnemX6lJAgq9ZeiZtS7I1PIf/la7CW4m83rTXRnVnsMbW2k56pGYu7AUFJD9Pow==",
"dev": true
},
+ "@tensorflow-models/body-pix": {
+ "version": "1.0.1",
+ "resolved": "https://registry.npmjs.org/@tensorflow-models/body-pix/-/body-pix-1.0.1.tgz",
+ "integrity": "sha512-s63eE+ns6ArGZ6MiX7t95yzUPUxdUkEUFwG8jWkhAlDSANiqFFM1rmNcIp2iF2ldojftFRr2y7vwHGocFYrDdQ=="
+ },
"@tensorflow-models/knn-classifier": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/@tensorflow-models/knn-classifier/-/knn-classifier-1.0.0.tgz",
@@ -334,6 +339,11 @@
"resolved": "https://registry.npmjs.org/@tensorflow-models/posenet/-/posenet-1.0.0.tgz",
"integrity": "sha512-uq/evoEMn3rD4J+yYIEp9S62bXn7mbMTEpSnnBz+rCEuzEEQkwNJByy99+K+OZfFPP71OtIkngDvrBhZxRe0OQ=="
},
+ "@tensorflow-models/speech-commands": {
+ "version": "0.3.8",
+ "resolved": "https://registry.npmjs.org/@tensorflow-models/speech-commands/-/speech-commands-0.3.8.tgz",
+ "integrity": "sha512-vuMennUQX4W7sBgo5F6TIFEXAx679qJADA2+d2//z0JODf5xCx038vq339K/cJzEIXGBlprbREFlYQUkojl8jQ=="
+ },
"@tensorflow/tfjs": {
"version": "1.1.2",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs/-/tfjs-1.1.2.tgz",
diff --git a/package.json b/package.json
index df00a1f6e..5bb304f4c 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
{
"name": "ml5",
- "version": "0.3.0",
+ "version": "0.3.1",
"description": "A friendly machine learning library for the web.",
"main": "dist/ml5.min.js",
"directories": {
@@ -92,6 +92,8 @@
"@tensorflow-models/knn-classifier": "1.0.0",
"@tensorflow-models/mobilenet": "1.0.0",
"@tensorflow-models/posenet": "1.0.0",
+ "@tensorflow-models/speech-commands": "0.3.8",
+ "@tensorflow-models/body-pix": "1.0.1",
"@tensorflow/tfjs": "1.1.2",
"events": "^3.0.0"
}
diff --git a/src/BodyPix/index.js b/src/BodyPix/index.js
new file mode 100644
index 000000000..62dd1665f
--- /dev/null
+++ b/src/BodyPix/index.js
@@ -0,0 +1,428 @@
+// Copyright (c) 2019 ml5
+//
+// This software is released under the MIT License.
+// https://opensource.org/licenses/MIT
+
+/* eslint prefer-destructuring: ["error", {AssignmentExpression: {array: false}}] */
+/* eslint no-await-in-loop: "off" */
+
+/*
+ * BodyPix: Real-time Person Segmentation in the Browser
+ * Ported and integrated from all the hard work by: https://github.com/tensorflow/tfjs-models/tree/master/body-pix
+ */
+
+import * as tf from '@tensorflow/tfjs';
+import * as bp from '@tensorflow-models/body-pix';
+import callCallback from '../utils/callcallback';
+import * as p5Utils from '../utils/p5Utils';
+
+const DEFAULTS = {
+ "multiplier": 0.75,
+ "outputStride": 16,
+ "segmentationThreshold": 0.5,
+ "palette": {
+ // "none": {
+ // "id": -1,
+ // "color": [0, 0, 0]
+ // },
+ "leftFace": {
+ "id": 0,
+ "color": [110, 64, 170]
+ },
+ "rightFace": {
+ "id": 1,
+ "color": [106, 72, 183]
+ },
+ "rightUpperLegFront": {
+ "id": 2,
+ "color": [100, 81, 196]
+ },
+ "rightLowerLegBack": {
+ "id": 3,
+ "color": [92, 91, 206]
+ },
+ "rightUpperLegBack": {
+ "id": 4,
+ "color": [84, 101, 214]
+ },
+ "leftLowerLegFront": {
+ "id": 5,
+ "color": [75, 113, 221]
+ },
+ "leftUpperLegFront": {
+ "id": 6,
+ "color": [66, 125, 224]
+ },
+ "leftUpperLegBack": {
+ "id": 7,
+ "color": [56, 138, 226]
+ },
+ "leftLowerLegBack": {
+ "id": 8,
+ "color": [48, 150, 224]
+ },
+ "rightFeet": {
+ "id": 9,
+ "color": [40, 163, 220]
+ },
+ "rightLowerLegFront": {
+ "id": 10,
+ "color": [33, 176, 214]
+ },
+ "leftFeet": {
+ "id": 11,
+ "color": [29, 188, 205]
+ },
+ "torsoFront": {
+ "id": 12,
+ "color": [26, 199, 194]
+ },
+ "torsoBack": {
+ "id": 13,
+ "color": [26, 210, 182]
+ },
+ "rightUpperArmFront": {
+ "id": 14,
+ "color": [28, 219, 169]
+ },
+ "rightUpperArmBack": {
+ "id": 15,
+ "color": [33, 227, 155]
+ },
+ "rightLowerArmBack": {
+ "id": 16,
+ "color": [41, 234, 141]
+ },
+ "leftLowerArmFront": {
+ "id": 17,
+ "color": [51, 240, 128]
+ },
+ "leftUpperArmFront": {
+ "id": 18,
+ "color": [64, 243, 116]
+ },
+ "leftUpperArmBack": {
+ "id": 19,
+ "color": [79, 246, 105]
+ },
+ "leftLowerArmBack": {
+ "id": 20,
+ "color": [96, 247, 97]
+ },
+ "rightHand": {
+ "id": 21,
+ "color": [115, 246, 91]
+ },
+ "rightLowerArmFront": {
+ "id": 22,
+ "color": [134, 245, 88]
+ },
+ "leftHand": {
+ "id": 23,
+ "color": [155, 243, 88]
+ }
+ }
+}
+
+class BodyPix {
+ /**
+ * Create BodyPix.
+ * @param {HTMLVideoElement} video - An HTMLVideoElement.
+ * @param {object} options - An object with options.
+ * @param {function} callback - A callback to be called when the model is ready.
+ */
+ constructor(video, options, callback) {
+ this.video = video;
+ this.model = null;
+ this.modelReady = false;
+ this.modelPath = '';
+ this.config = {
+ multiplier: options.multiplier || DEFAULTS.multiplier,
+ outputStride: options.outputStride || DEFAULTS.outputStride,
+ segmentationThreshold: options.segmentationThreshold || DEFAULTS.segmentationThreshold,
+ palette: options.palette || DEFAULTS.palette
+ }
+
+ this.ready = callCallback(this.loadModel(), callback);
+ }
+
+ /**
+ * Load the model and set it to this.model
+ * @return {this} the BodyPix model.
+ */
+ async loadModel() {
+ this.model = await bp.load(this.config.multiplier);
+ this.modelReady = true;
+ return this;
+ }
+
+ /**
+ * Returns an rgb array
+ * @param {Object} p5ColorObj - a p5.Color object
+ * @return {Array} an [r,g,b] array
+ */
+ /* eslint class-methods-use-this: "off" */
+ p5Color2RGB(p5ColorObj) {
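+ // e.g. p5ColorObj.toString('rgb') gives "rgba(110,64,170,1)"; capture the comma-separated values between the parentheses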
+ const regExp = /\(([^)]+)\)/;
+ const match = regExp.exec(p5ColorObj.toString('rgb'));
+ const [r, g, b] = match[1].split(',').map(n => Number(n.trim()));
+ return [r, g, b];
+ }
+
+ /**
+ * Returns a bodyPartsSpec object
+ * @param {object} colorOptions - an object mapping body parts to { id, color } entries
+ * @return {object} an object with the bodyParts by color and id
+ */
+ /* eslint class-methods-use-this: "off" */
+ bodyPartsSpec(colorOptions) {
+ const result = colorOptions !== undefined && Object.keys(colorOptions).length >= 24 ? colorOptions : this.config.palette;
+
+ // Check if we're getting p5 colors, make sure they are rgb
+ if (p5Utils.checkP5() && result !== undefined && Object.keys(result).length >= 24) {
+ // Ensure the p5Color object is an RGB array
+ Object.keys(result).forEach(part => {
+ if (result[part].color instanceof window.p5.Color) {
+ result[part].color = this.p5Color2RGB(result[part].color);
+ }
+ });
+ }
+
+ return result;
+ }
+
+ /**
+ * Segments the image with partSegmentation, return result object
+ * @param {HTMLImageElement | HTMLCanvasElement | HTMLVideoElement} imgToSegment -
+ * the image, canvas, or video to segment
+ * @param {object} segmentationOptions - config params for the segmentation
+ * includes outputStride, segmentationThreshold
+ * @return {Object} a result object with image, raw, bodyParts
+ */
+ async segmentWithPartsInternal(imgToSegment, segmentationOptions) {
+ // estimatePartSegmentation
+ await this.ready;
+ await tf.nextFrame();
+
+ if (this.video && this.video.readyState === 0) {
+ await new Promise(resolve => {
+ this.video.onloadeddata = () => resolve();
+ });
+ }
+
+ this.config.palette = segmentationOptions.palette || this.config.palette;
+ this.config.outputStride = segmentationOptions.outputStride || this.config.outputStride;
+ this.config.segmentationThreshold = segmentationOptions.segmentationThreshold || this.config.segmentationThreshold;
+
+ const bodyPartsMeta = this.bodyPartsSpec(this.config.palette);
+ const segmentation = await this.model.estimatePartSegmentation(imgToSegment, this.config.outputStride, this.config.segmentationThreshold);
+
+ const colorsArray = Object.keys(bodyPartsMeta).map(part => bodyPartsMeta[part].color)
+
+ const result = {};
+ result.image = bp.toColoredPartImageData(segmentation, colorsArray);
+ result.raw = segmentation;
+ result.bodyParts = bodyPartsMeta;
+
+ if (p5Utils.checkP5()) {
+ const blob1 = await p5Utils.rawToBlob(result.image.data, segmentation.width, segmentation.height);
+ const p5Image1 = await p5Utils.blobToP5Image(blob1);
+ result.image = p5Image1;
+ }
+
+ return result;
+
+ }
+
+ /**
+ * Segments the image with partSegmentation
+ * @param {HTMLImageElement | HTMLCanvasElement | object | function | number} optionsOrCallback -
+ * takes any of the following params
+ * @param {object} configOrCallback - config params for the segmentation
+ * includes palette, outputStride, segmentationThreshold
+ * @param {function} cb - a callback function that handles the results of the function.
+ * @return {function} a promise or the results of a given callback, cb.
+ */
+ async segmentWithParts(optionsOrCallback, configOrCallback, cb) {
+ let imgToSegment = this.video;
+ let callback;
+ let segmentationOptions = this.config;
+
+ // Handle the image to predict
+ if (typeof optionsOrCallback === 'function') {
+ imgToSegment = this.video;
+ callback = optionsOrCallback;
+ // clean the following conditional statement up!
+ } else if (optionsOrCallback instanceof HTMLImageElement) {
+ imgToSegment = optionsOrCallback;
+ } else if (
+ typeof optionsOrCallback === 'object' &&
+ optionsOrCallback.elt instanceof HTMLImageElement
+ ) {
+ imgToSegment = optionsOrCallback.elt; // Handle p5.js image
+ } else if (optionsOrCallback instanceof HTMLCanvasElement) {
+ imgToSegment = optionsOrCallback;
+ } else if (
+ typeof optionsOrCallback === 'object' &&
+ optionsOrCallback.elt instanceof HTMLCanvasElement
+ ) {
+ imgToSegment = optionsOrCallback.elt; // Handle p5.js image
+ } else if (
+ typeof optionsOrCallback === 'object' &&
+ optionsOrCallback.canvas instanceof HTMLCanvasElement
+ ) {
+ imgToSegment = optionsOrCallback.canvas; // Handle p5.js image
+ } else if (!(this.video instanceof HTMLVideoElement)) {
+ // Handle unsupported input
+ throw new Error(
+ 'No input image provided. If you want to segment a video, pass the video element in the constructor.',
+ );
+ }
+
+ if (typeof configOrCallback === 'object') {
+ segmentationOptions = configOrCallback;
+ } else if (typeof configOrCallback === 'function') {
+ callback = configOrCallback;
+ }
+
+ if (typeof cb === 'function') {
+ callback = cb;
+ }
+
+ return callCallback(this.segmentWithPartsInternal(imgToSegment, segmentationOptions), callback);
+
+ }
+
+ /**
+ * Segments the image with personSegmentation, return result object
+ * @param {HTMLImageElement | HTMLCanvasElement | HTMLVideoElement} imgToSegment -
+ * the image, canvas, or video to segment
+ * @param {object} segmentationOptions - config params for the segmentation
+ * includes outputStride, segmentationThreshold
+ * @return {Object} a result object with maskBackground, maskPerson, raw
+ */
+ async segmentInternal(imgToSegment, segmentationOptions) {
+ await this.ready;
+ await tf.nextFrame();
+
+ if (this.video && this.video.readyState === 0) {
+ await new Promise(resolve => {
+ this.video.onloadeddata = () => resolve();
+ });
+ }
+
+ this.config.outputStride = segmentationOptions.outputStride || this.config.outputStride;
+ this.config.segmentationThreshold = segmentationOptions.segmentationThreshold || this.config.segmentationThreshold;
+
+ const segmentation = await this.model.estimatePersonSegmentation(imgToSegment, this.config.outputStride, this.config.segmentationThreshold)
+
+ const result = {};
+ result.maskBackground = bp.toMaskImageData(segmentation, true);
+ result.maskPerson = bp.toMaskImageData(segmentation, false);
+ result.raw = segmentation;
+
+ if (p5Utils.checkP5()) {
+ const blob1 = await p5Utils.rawToBlob(result.maskBackground.data, segmentation.width, segmentation.height);
+ const blob2 = await p5Utils.rawToBlob(result.maskPerson.data, segmentation.width, segmentation.height);
+ const p5Image1 = await p5Utils.blobToP5Image(blob1);
+ const p5Image2 = await p5Utils.blobToP5Image(blob2);
+
+ result.maskBackground = p5Image1;
+ result.maskPerson = p5Image2;
+ }
+
+ return result;
+
+ }
+
+ /**
+ * Segments the image with personSegmentation
+ * @param {HTMLImageElement | HTMLCanvasElement | object | function | number} optionsOrCallback -
+ * takes any of the following params
+ * @param {object} configOrCallback - config params for the segmentation
+ * includes outputStride, segmentationThreshold
+ * @param {function} cb - a callback function that handles the results of the function.
+ * @return {function} a promise or the results of a given callback, cb.
+ */
+ async segment(optionsOrCallback, configOrCallback, cb) {
+ let imgToSegment = this.video;
+ let callback;
+ let segmentationOptions = this.config;
+
+ // Handle the image to predict
+ if (typeof optionsOrCallback === 'function') {
+ imgToSegment = this.video;
+ callback = optionsOrCallback;
+ // clean the following conditional statement up!
+ } else if (optionsOrCallback instanceof HTMLImageElement) {
+ imgToSegment = optionsOrCallback;
+ } else if (
+ typeof optionsOrCallback === 'object' &&
+ optionsOrCallback.elt instanceof HTMLImageElement
+ ) {
+ imgToSegment = optionsOrCallback.elt; // Handle p5.js image
+ } else if (optionsOrCallback instanceof HTMLCanvasElement) {
+ imgToSegment = optionsOrCallback;
+ } else if (
+ typeof optionsOrCallback === 'object' &&
+ optionsOrCallback.elt instanceof HTMLCanvasElement
+ ) {
+ imgToSegment = optionsOrCallback.elt; // Handle p5.js image
+ } else if (
+ typeof optionsOrCallback === 'object' &&
+ optionsOrCallback.canvas instanceof HTMLCanvasElement
+ ) {
+ imgToSegment = optionsOrCallback.canvas; // Handle p5.js image
+ } else if (!(this.video instanceof HTMLVideoElement)) {
+ // Handle unsupported input
+ throw new Error(
+ 'No input image provided. If you want to segment a video, pass the video element in the constructor.',
+ );
+ }
+
+ if (typeof configOrCallback === 'object') {
+ segmentationOptions = configOrCallback;
+ } else if (typeof configOrCallback === 'function') {
+ callback = configOrCallback;
+ }
+
+ if (typeof cb === 'function') {
+ callback = cb;
+ }
+
+ return callCallback(this.segmentInternal(imgToSegment, segmentationOptions), callback);
+ }
+
+}
+
+const bodyPix = (videoOrOptionsOrCallback, optionsOrCallback, cb) => {
+ let video;
+ let options = {};
+ let callback = cb;
+
+ if (videoOrOptionsOrCallback instanceof HTMLVideoElement) {
+ video = videoOrOptionsOrCallback;
+ } else if (
+ typeof videoOrOptionsOrCallback === 'object' &&
+ videoOrOptionsOrCallback.elt instanceof HTMLVideoElement
+ ) {
+ video = videoOrOptionsOrCallback.elt; // Handle a p5.js video element
+ } else if (typeof videoOrOptionsOrCallback === 'object') {
+ options = videoOrOptionsOrCallback;
+ } else if (typeof videoOrOptionsOrCallback === 'function') {
+ callback = videoOrOptionsOrCallback;
+ }
+
+ if (typeof optionsOrCallback === 'object') {
+ options = optionsOrCallback;
+ } else if (typeof optionsOrCallback === 'function') {
+ callback = optionsOrCallback;
+ }
+
+ const instance = new BodyPix(video, options, callback);
+ return callback ? instance : instance.ready;
+}
+
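+// A minimal usage sketch (assumes a p5.js video element; callback style shown):
+//   const bodypix = ml5.bodyPix(video, () => console.log('BodyPix ready'));
+//   bodypix.segment((err, result) => {
+//     if (!err) image(result.maskBackground, 0, 0);
+//   });
+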
+export default bodyPix;
\ No newline at end of file
diff --git a/src/BodyPix/index_test.js b/src/BodyPix/index_test.js
new file mode 100644
index 000000000..a3d8279b6
--- /dev/null
+++ b/src/BodyPix/index_test.js
@@ -0,0 +1,63 @@
+// Copyright (c) 2018 ml5
+//
+// This software is released under the MIT License.
+// https://opensource.org/licenses/MIT
+
+const { bodyPix } = ml5;
+
+const BODYPIX_DEFAULTS = {
+ "multiplier": 0.75,
+ "outputStride": 16,
+ "segmentationThreshold": 0.5
+};
+
+describe('bodyPix', () => {
+ let bp;
+
+ async function getImage() {
+ const img = new Image();
+ img.crossOrigin = true;
+ img.src = 'https://cdn.jsdelivr.net/gh/ml5js/ml5-data-and-models@master/tests/images/harriet_128x128.jpg';
+ await new Promise((resolve) => { img.onload = resolve; });
+ return img;
+ }
+
+ async function getCanvas() {
+ const img = await getImage();
+ const canvas = document.createElement('canvas');
+ canvas.width = img.width;
+ canvas.height = img.height;
+ canvas.getContext('2d').drawImage(img, 0, 0);
+ return canvas;
+ }
+
+ beforeEach(async () => {
+ jasmine.DEFAULT_TIMEOUT_INTERVAL = 5000;
+ bp = await bodyPix();
+ });
+
+ it('Should create bodyPix with all the defaults', async () => {
+ expect(bp.config.multiplier).toBe(BODYPIX_DEFAULTS.multiplier);
+ expect(bp.config.outputStride).toBe(BODYPIX_DEFAULTS.outputStride);
+ expect(bp.config.segmentationThreshold).toBe(BODYPIX_DEFAULTS.segmentationThreshold);
+ });
+
+ describe('segmentation', () => {
+ it('Should segment an image of Harriet Tubman with a width and height of 128', async () => {
+ const img = await getImage();
+ const results = await bp.segment(img);
+
+ expect(results.maskBackground.width).toBe(128);
+ expect(results.maskBackground.height).toBe(128);
+
+ expect(results.maskPerson.width).toBe(128);
+ expect(results.maskPerson.height).toBe(128);
+
+ expect(results.raw.width).toBe(128);
+ expect(results.raw.height).toBe(128);
+ });
+
+ });
+});
diff --git a/src/DCGAN/index.js b/src/DCGAN/index.js
index 5cfa1c1ac..a706523bc 100644
--- a/src/DCGAN/index.js
+++ b/src/DCGAN/index.js
@@ -12,26 +12,28 @@ import * as tf from '@tensorflow/tfjs';
import callCallback from '../utils/callcallback';
import * as p5Utils from '../utils/p5Utils';
-const allModelInfo = {
- face: {
- description: 'DCGAN, human faces, 64x64',
- modelUrl: "https://raw.githubusercontent.com/viztopia/ml5dcgan/master/model/model.json", // "https://github.com/viztopia/ml5dcgan/blob/master/model/model.json",
- modelSize: 64,
- modelLatentDim: 128
- }
-};
+// Default pre-trained face model
+
+// const DEFAULT = {
+// "description": "DCGAN, human faces, 64x64",
+// "model": "https://raw.githubusercontent.com/ml5js/ml5-data-and-models/master/models/dcgan/face/model.json",
+// "modelSize": 64,
+// "modelLatentDim": 128
+// }
-class DCGANBase{
+class DCGANBase {
/**
* Create a DCGAN.
* @param {string} modelPath - The path to the model's "manifest.json" file.
* @param {function} callback - A callback to be called when the model is ready.
*/
- constructor(modelName, readyCb){
- this.modelCache = {};
- this.modelName = modelName;
- this.model = null;
- this.ready = callCallback(this.loadModel(), readyCb);
+ constructor(modelPath, callback) {
+ this.model = {};
+ this.modelPath = modelPath;
+ this.modelInfo = {};
+ this.modelPathPrefix = '';
+ this.modelReady = false;
+ this.ready = callCallback(this.loadModel(), callback);
}
/**
@@ -39,17 +41,16 @@ class DCGANBase{
* @return {this} the dcgan.
*/
async loadModel() {
- const {modelName} = this;
- const modelInfo = allModelInfo[modelName];
- const {modelUrl} = modelInfo;
+ const modelInfo = await fetch(this.modelPath);
+ const modelInfoJson = await modelInfo.json();
- if (modelName in this.modelCache) {
- this.model = this.modelCache[modelName];
- return this;
- }
+ this.modelInfo = modelInfoJson;
+
+ const [modelUrl] = this.modelPath.split('manifest.json');
+ this.modelPathPrefix = modelUrl;
+ const modelJsonPath = this.isAbsoluteURL(this.modelInfo.model) ? this.modelInfo.model : this.modelPathPrefix + this.modelInfo.model;
- this.model = await tf.loadLayersModel(modelUrl);
- this.modelCache[modelName] = this.model;
+ this.model = await tf.loadLayersModel(modelJsonPath);
+ this.modelReady = true;
return this;
}
@@ -58,10 +59,11 @@ class DCGANBase{
* @param {function} callback - a callback function to handle the results of generate
* @return {object} a promise or the result of the callback function.
*/
- async generate(callback){
+ async generate(callback) {
+ await this.ready;
return callCallback(this.generateInternal(), callback);
}
-
+
/**
* Computes what will become the image tensor
* @param {number} latentDim - the number of latent dimensions to pass through
@@ -83,30 +85,30 @@ class DCGANBase{
* @return {object} includes blob, raw, and tensor. if P5 exists, then a p5Image
*/
async generateInternal() {
- const modelInfo = allModelInfo[this.modelName];
- const {modelLatentDim} = modelInfo;
+ const {
+ modelLatentDim
+ } = this.modelInfo;
const imageTensor = await this.compute(modelLatentDim);
// get the raw data from tensor
const raw = await tf.browser.toPixels(imageTensor);
-
// get the blob from raw
const [imgHeight, imgWidth] = imageTensor.shape;
const blob = await p5Utils.rawToBlob(raw, imgWidth, imgHeight);
// get the p5.Image object
let p5Image;
- if(p5Utils.checkP5()){
+ if (p5Utils.checkP5()) {
p5Image = await p5Utils.blobToP5Image(blob);
}
// wrap up the final js result object
- const result = {};
+ const result = {};
result.blob = blob;
result.raw = raw;
result.tensor = imageTensor;
- if(p5Utils.checkP5()){
+ if (p5Utils.checkP5()) {
result.image = p5Image;
}
@@ -114,8 +116,34 @@ class DCGANBase{
}
+
+ /* eslint class-methods-use-this: "off" */
+ isAbsoluteURL(str) {
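+ // e.g. "https://example.com/model.json" -> true; "models/face/model.json" -> false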
+ const pattern = new RegExp('^(?:[a-z]+:)?//', 'i');
+ return !!pattern.test(str);
+ }
+
}
-const DCGAN = (modelName, callback) => new DCGANBase( modelName, callback ) ;
+const DCGAN = (modelPath, cb) => {
+
+ if (typeof modelPath !== 'string') {
+ throw new Error(`Please specify a path to a "manifest.json" file: \n
+ "models/face/manifest.json" \n\n
+ This "manifest.json" file should include:\n
+ {
+ "description": "DCGAN, human faces, 64x64",
+ "model": "https://raw.githubusercontent.com/viztopia/ml5dcgan/master/model/model.json", // "https://github.com/viztopia/ml5dcgan/blob/master/model/model.json",
+ "modelSize": 64,
+ "modelLatentDim": 128
+ }
+ `);
+ }
+
+
+ const instance = new DCGANBase(modelPath, cb);
+ return cb ? instance : instance.ready;
+
+}
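+
+// A minimal usage sketch ('models/face/manifest.json' is the sample path from the error message above):
+//   const dcgan = await ml5.DCGAN('models/face/manifest.json');
+//   const result = await dcgan.generate(); // result.image is a p5.Image when p5 is available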
export default DCGAN;
diff --git a/src/FeatureExtractor/Mobilenet.js b/src/FeatureExtractor/Mobilenet.js
index 4ea52bd0d..99d1372c7 100644
--- a/src/FeatureExtractor/Mobilenet.js
+++ b/src/FeatureExtractor/Mobilenet.js
@@ -65,6 +65,7 @@ class Mobilenet {
*/
this.hasAnyTrainedClass = false;
this.customModel = null;
+ this.jointModel = null;
this.config = {
epochs: options.epochs || DEFAULTS.epochs,
version: options.version || DEFAULTS.version,
@@ -104,9 +105,9 @@ class Mobilenet {
const layer = this.mobilenet.getLayer(this.config.layer);
this.mobilenetFeatures = await tf.model({ inputs: this.mobilenet.inputs, outputs: layer.output });
- // if (this.video) {
- // await this.mobilenet.classify(imgToTensor(this.video)); // Warm up
- // }
+ if (this.video) {
+ await this.mobilenetFeatures.predict(imgToTensor(this.video)); // Warm up
+ }
return this;
}
@@ -300,6 +301,9 @@ class Mobilenet {
],
});
}
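+ // Chain the frozen MobileNet feature extractor and the custom head into one model so predict() runs in a single pass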
+ this.jointModel = tf.sequential();
+ this.jointModel.add(this.mobilenetFeatures); // mobilenet
+ this.jointModel.add(this.customModel); // transfer layer
const optimizer = tf.train.adam(this.config.learningRate);
this.customModel.compile({ optimizer, loss: this.loss });
@@ -358,8 +362,7 @@ class Mobilenet {
const predictedClasses = tf.tidy(() => {
const imageResize = (imgToPredict === this.video) ? null : [IMAGE_SIZE, IMAGE_SIZE];
const processedImg = imgToTensor(imgToPredict, imageResize);
- const activation = this.mobilenetFeatures.predict(processedImg);
- const predictions = this.customModel.predict(activation);
+ const predictions = this.jointModel.predict(processedImg);
return Array.from(predictions.as1D().dataSync());
});
const results = await predictedClasses.map((confidence, index) => {
@@ -407,8 +410,7 @@ class Mobilenet {
const predictedClass = tf.tidy(() => {
const imageResize = (imgToPredict === this.video) ? null : [IMAGE_SIZE, IMAGE_SIZE];
const processedImg = imgToTensor(imgToPredict, imageResize);
- const activation = this.mobilenetFeatures.predict(processedImg);
- const predictions = this.customModel.predict(activation);
+ const predictions = this.jointModel.predict(processedImg);
return predictions.as1D();
});
const prediction = await predictedClass.data();
@@ -425,31 +427,37 @@ class Mobilenet {
model = file;
const fr = new FileReader();
fr.onload = (d) => {
- this.mapStringToIndex = JSON.parse(d.target.result).ml5Specs.mapStringToIndex;
+ if (JSON.parse(d.target.result).ml5Specs) {
+ this.mapStringToIndex = JSON.parse(d.target.result).ml5Specs.mapStringToIndex;
+ }
};
fr.readAsText(file);
} else if (file.name.includes('.bin')) {
weights = file;
}
});
- this.customModel = await tf.loadLayersModel(tf.io.browserFiles([model, weights]));
+ this.jointModel = await tf.loadLayersModel(tf.io.browserFiles([model, weights]));
} else {
fetch(filesOrPath)
.then(r => r.json())
- .then((r) => { this.mapStringToIndex = r.ml5Specs.mapStringToIndex; });
- this.customModel = await tf.loadLayersModel(filesOrPath);
+ .then((r) => {
+ if (r.ml5Specs) {
+ this.mapStringToIndex = r.ml5Specs.mapStringToIndex;
+ }
+ });
+ this.jointModel = await tf.loadLayersModel(filesOrPath);
if (callback) {
callback();
}
}
- return this.customModel;
+ return this.jointModel;
}
async save(callback, name) {
- if (!this.customModel) {
+ if (!this.jointModel) {
throw new Error('No model found.');
}
- this.customModel.save(tf.io.withSaveHandler(async (data) => {
+ this.jointModel.save(tf.io.withSaveHandler(async (data) => {
let modelName = 'model';
if(name) modelName = name;
diff --git a/src/ImageClassifier/darknet.js b/src/ImageClassifier/darknet.js
index ba0e0bbf8..ceda257b2 100644
--- a/src/ImageClassifier/darknet.js
+++ b/src/ImageClassifier/darknet.js
@@ -4,49 +4,22 @@
// https://opensource.org/licenses/MIT
import * as tf from '@tensorflow/tfjs';
+import { getTopKClassesFromTensor } from '../utils/gettopkclasses';
import IMAGENET_CLASSES_DARKNET from '../utils/IMAGENET_CLASSES_DARKNET';
const DEFAULTS = {
- DARKNET_URL: 'https://rawgit.com/ml5js/ml5-data-and-models/master/models/darknetclassifier/darknetreference/model.json',
- DARKNET_TINY_URL: 'https://rawgit.com/ml5js/ml5-data-and-models/master/models/darknetclassifier/darknettiny/model.json',
+ DARKNET_URL: 'https://cdn.jsdelivr.net/gh/ml5js/ml5-data-and-models@master/models/darknetclassifier/darknetreference/model.json',
+ DARKNET_TINY_URL: 'https://cdn.jsdelivr.net/gh/ml5js/ml5-data-and-models@master/models/darknetclassifier/darknettiny/model.json',
IMAGE_SIZE_DARKNET: 256,
IMAGE_SIZE_DARKNET_TINY: 224,
};
-async function getTopKClasses(logits, topK) {
- const values = await logits.data();
- const valuesAndIndices = [];
- for (let i = 0; i < values.length; i += 1) {
- valuesAndIndices.push({
- value: values[i],
- index: i,
- });
- }
- valuesAndIndices.sort((a, b) => b.value - a.value);
-
- const topkValues = new Float32Array(topK);
- const topkIndices = new Int32Array(topK);
- for (let i = 0; i < topK; i += 1) {
- topkValues[i] = valuesAndIndices[i].value;
- topkIndices[i] = valuesAndIndices[i].index;
- }
-
- const topClassesAndProbs = [];
- for (let i = 0; i < topkIndices.length; i += 1) {
- topClassesAndProbs.push({
- className: IMAGENET_CLASSES_DARKNET[topkIndices[i]],
- probability: topkValues[i],
- });
- }
- return topClassesAndProbs;
-}
-
function preProcess(img, size) {
let image;
if (!(img instanceof tf.Tensor)) {
- if (img instanceof HTMLImageElement || img instanceof HTMLVideoElement) {
+ if (img instanceof HTMLImageElement || img instanceof HTMLVideoElement || img instanceof HTMLCanvasElement) {
image = tf.browser.fromPixels(img);
- } else if (typeof img === 'object' && (img.elt instanceof HTMLImageElement || img.elt instanceof HTMLVideoElement)) {
+ } else if (typeof img === 'object' && (img.elt instanceof HTMLImageElement || img.elt instanceof HTMLVideoElement || img.elt instanceof HTMLCanvasElement)) {
image = tf.browser.fromPixels(img.elt); // Handle p5.js image and video.
}
} else {
@@ -101,7 +74,7 @@ export class Darknet {
const predictions = this.model.predict(imgData);
return tf.softmax(predictions);
});
- const classes = await getTopKClasses(logits, topk);
+ const classes = await getTopKClassesFromTensor(logits, topk, IMAGENET_CLASSES_DARKNET);
logits.dispose();
return classes;
}
diff --git a/src/ImageClassifier/doodlenet.js b/src/ImageClassifier/doodlenet.js
new file mode 100644
index 000000000..3bfc3e922
--- /dev/null
+++ b/src/ImageClassifier/doodlenet.js
@@ -0,0 +1,67 @@
+// Copyright (c) 2018 ml5
+//
+// This software is released under the MIT License.
+// https://opensource.org/licenses/MIT
+
+import * as tf from '@tensorflow/tfjs';
+import { getTopKClassesFromTensor } from '../utils/gettopkclasses';
+import DOODLENET_CLASSES from '../utils/DOODLENET_CLASSES';
+
+const DEFAULTS = {
+ DOODLENET_URL: 'https://cdn.jsdelivr.net/gh/ml5js/ml5-data-and-models@master/models/doodlenet/model.json',
+ IMAGE_SIZE_DOODLENET: 28,
+};
+
+function preProcess(img, size) {
+ let image;
+ if (!(img instanceof tf.Tensor)) {
+ if (img instanceof HTMLImageElement || img instanceof HTMLVideoElement || img instanceof HTMLCanvasElement) {
+ image = tf.browser.fromPixels(img);
+ } else if (typeof img === 'object' && (img.elt instanceof HTMLImageElement || img.elt instanceof HTMLVideoElement || img.elt instanceof HTMLCanvasElement)) {
+ image = tf.browser.fromPixels(img.elt); // Handle p5.js image, video and canvas.
+ }
+ } else {
+ image = img;
+ }
+ const normalized = tf.scalar(1).sub(image.toFloat().div(tf.scalar(255)));
+ let resized = normalized;
+ if (normalized.shape[0] !== size || normalized.shape[1] !== size) {
+ resized = tf.image.resizeBilinear(normalized, [size, size]);
+ }
+ const [r, g, b] = tf.split(resized, 3, 2);
+ const gray = (r.add(g).add(b)).div(tf.scalar(3)).round(); // Get average r,g,b color value and round to 0 or 1
+ const batched = gray.reshape([1, size, size, 1]);
+ return batched;
+}
+
+export class Doodlenet {
+ constructor() {
+ this.imgSize = DEFAULTS.IMAGE_SIZE_DOODLENET;
+ }
+
+ async load() {
+ this.model = await tf.loadLayersModel(DEFAULTS.DOODLENET_URL);
+
+ // Warmup the model.
+ const result = tf.tidy(() => this.model.predict(tf.zeros([1, this.imgSize, this.imgSize, 1])));
+ await result.data();
+ result.dispose();
+ }
+
+ async classify(img, topk = 10) {
+ const logits = tf.tidy(() => {
+ const imgData = preProcess(img, this.imgSize);
+ const predictions = this.model.predict(imgData);
+ return predictions;
+ });
+ const classes = await getTopKClassesFromTensor(logits, topk, DOODLENET_CLASSES);
+ logits.dispose();
+ return classes;
+ }
+}
+
+export async function load() {
+ const doodlenet = new Doodlenet();
+ await doodlenet.load();
+ return doodlenet;
+}
diff --git a/src/ImageClassifier/index.js b/src/ImageClassifier/index.js
index 23dd34941..708c6548f 100644
--- a/src/ImageClassifier/index.js
+++ b/src/ImageClassifier/index.js
@@ -10,7 +10,9 @@ Image Classifier using pre-trained networks
import * as tf from '@tensorflow/tfjs';
import * as mobilenet from '@tensorflow-models/mobilenet';
import * as darknet from './darknet';
+import * as doodlenet from './doodlenet';
import callCallback from '../utils/callcallback';
+import { imgToTensor } from '../utils/imageUtilities';
const DEFAULTS = {
mobilenet: {
@@ -19,51 +21,77 @@ const DEFAULTS = {
topk: 3,
},
};
+const IMAGE_SIZE = 224;
+const MODEL_OPTIONS = ['mobilenet', 'darknet', 'darknet-tiny', 'doodlenet'];
class ImageClassifier {
/**
* Create an ImageClassifier.
- * @param {modelName} modelName - The name of the model to use. Current options
- * are: 'mobilenet', 'darknet', and 'darknet-tiny'.
+ * @param {string} modelNameOrUrl - The name or the URL of the model to use. Current model name options
+ * are: 'mobilenet', 'darknet', 'darknet-tiny', and 'doodlenet'.
* @param {HTMLVideoElement} video - An HTMLVideoElement.
* @param {object} options - An object with options.
* @param {function} callback - A callback to be called when the model is ready.
*/
- constructor(modelName, video, options, callback) {
- this.modelName = modelName;
+ constructor(modelNameOrUrl, video, options, callback) {
this.video = video;
this.model = null;
- switch (this.modelName) {
- case 'mobilenet':
- this.modelToUse = mobilenet;
- this.version = options.version || DEFAULTS.mobilenet.version;
- this.alpha = options.alpha || DEFAULTS.mobilenet.alpha;
- this.topk = options.topk || DEFAULTS.mobilenet.topk;
- break;
- case 'darknet':
- this.version = 'reference'; // this a 28mb model
- this.modelToUse = darknet;
- break;
- case 'darknet-tiny':
- this.version = 'tiny'; // this a 4mb model
- this.modelToUse = darknet;
- break;
- default:
- this.modelToUse = null;
+ this.mapStringToIndex = [];
+ if (typeof modelNameOrUrl === 'string') {
+ if (MODEL_OPTIONS.includes(modelNameOrUrl)) {
+ this.modelName = modelNameOrUrl;
+ this.modelUrl = null;
+ switch (this.modelName) {
+ case 'mobilenet':
+ this.modelToUse = mobilenet;
+ this.version = options.version || DEFAULTS.mobilenet.version;
+ this.alpha = options.alpha || DEFAULTS.mobilenet.alpha;
+ this.topk = options.topk || DEFAULTS.mobilenet.topk;
+ break;
+ case 'darknet':
+ this.version = 'reference'; // this is a 28mb model
+ this.modelToUse = darknet;
+ break;
+ case 'darknet-tiny':
+ this.version = 'tiny'; // this is a 4mb model
+ this.modelToUse = darknet;
+ break;
+ case 'doodlenet':
+ this.modelToUse = doodlenet;
+ break;
+ default:
+ this.modelToUse = null;
+ }
+ } else {
+ this.modelUrl = modelNameOrUrl;
+ }
}
// Load the model
- this.ready = callCallback(this.loadModel(), callback);
+ this.ready = callCallback(this.loadModel(this.modelUrl), callback);
}
/**
* Load the model and set it to this.model
* @return {this} The ImageClassifier.
*/
- async loadModel() {
- this.model = await this.modelToUse.load(this.version, this.alpha);
+ async loadModel(modelUrl) {
+ if (modelUrl) this.model = await this.loadModelFrom(modelUrl);
+ else this.model = await this.modelToUse.load(this.version, this.alpha);
return this;
}
+ async loadModelFrom(path = null) {
+ // Await the manifest fetch so mapStringToIndex is set before the model is used
+ const r = await fetch(path).then(res => res.json());
+ if (r.ml5Specs) {
+ this.mapStringToIndex = r.ml5Specs.mapStringToIndex;
+ }
+ this.model = await tf.loadLayersModel(path);
+ return this.model;
+ }
+
/**
* Classifies the given input and returns an object with labels and confidence
* @param {HTMLImageElement | HTMLCanvasElement | HTMLVideoElement} imgToPredict -
@@ -77,11 +105,37 @@ class ImageClassifier {
await this.ready;
await tf.nextFrame();
+ if (imgToPredict instanceof HTMLVideoElement && imgToPredict.readyState === 0) {
+ const video = imgToPredict;
+ // Wait for the video to be ready
+ await new Promise(resolve => {
+ video.onloadeddata = () => resolve();
+ });
+ }
+
if (this.video && this.video.readyState === 0) {
await new Promise(resolve => {
this.video.onloadeddata = () => resolve();
});
}
+
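+ // Models loaded from a raw URL are plain tf.LayersModels with no classify() helper, so run the tensor pipeline manually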
+ if (this.modelUrl) {
+ await tf.nextFrame();
+ const predictedClasses = tf.tidy(() => {
+ const imageResize = [IMAGE_SIZE, IMAGE_SIZE];
+ const processedImg = imgToTensor(imgToPredict, imageResize);
+ const predictions = this.model.predict(processedImg);
+ return Array.from(predictions.as1D().dataSync());
+ });
+ const results = await predictedClasses.map((confidence, index) => {
+ const label = (this.mapStringToIndex.length > 0 && this.mapStringToIndex[index]) ? this.mapStringToIndex[index] : index;
+ return {
+ label,
+ confidence,
+ };
+ }).sort((a, b) => b.confidence - a.confidence);
+ return results;
+ }
return this.model
.classify(imgToPredict, numberOfClasses)
.then(classes => classes.map(c => ({ label: c.className, confidence: c.probability })));
@@ -127,6 +181,13 @@ class ImageClassifier {
inputNumOrCallback.canvas instanceof HTMLCanvasElement
) {
imgToPredict = inputNumOrCallback.canvas; // Handle p5.js image
+ } else if (inputNumOrCallback instanceof HTMLVideoElement) {
+ imgToPredict = inputNumOrCallback;
+ } else if (
+ typeof inputNumOrCallback === 'object' &&
+ inputNumOrCallback.elt instanceof HTMLVideoElement
+ ) {
+ imgToPredict = inputNumOrCallback.elt; // Handle p5.js video
} else if (!(this.video instanceof HTMLVideoElement)) {
// Handle unsupported input
throw new Error(
diff --git a/src/ImageClassifier/index_test.js b/src/ImageClassifier/index_test.js
index 51b0664b6..cd9b4ec00 100644
--- a/src/ImageClassifier/index_test.js
+++ b/src/ImageClassifier/index_test.js
@@ -22,7 +22,7 @@ describe('imageClassifier', () => {
async function getImage() {
const img = new Image();
img.crossOrigin = true;
- img.src = 'https://ml5js.org/docs/assets/img/bird.jpg';
+ img.src = 'https://cdn.jsdelivr.net/gh/ml5js/ml5-library@development/assets/bird.jpg';
await new Promise((resolve) => { img.onload = resolve; });
return img;
}
diff --git a/src/SoundClassifier/index.js b/src/SoundClassifier/index.js
new file mode 100644
index 000000000..e9a0a5334
--- /dev/null
+++ b/src/SoundClassifier/index.js
@@ -0,0 +1,105 @@
+// Copyright (c) 2019 ml5
+//
+// This software is released under the MIT License.
+// https://opensource.org/licenses/MIT
+
+/*
+Sound Classifier using pre-trained networks
+*/
+
+import * as tf from '@tensorflow/tfjs';
+import * as speechCommands from './speechcommands';
+import callCallback from '../utils/callcallback';
+
+const MODEL_OPTIONS = ['speechcommands18w'];
+class SoundClassifier {
+ /**
+ * Create a SoundClassifier.
+ * @param {string} modelNameOrUrl - The name or the URL of the model to use. Current name options
+ * are: 'SpeechCommands18w'.
+ * @param {object} options - An object with options.
+ * @param {function} callback - A callback to be called when the model is ready.
+ */
+ constructor(modelNameOrUrl, options, callback) {
+ this.model = null;
+ this.options = options;
+ if (typeof modelNameOrUrl === 'string') {
+ if (MODEL_OPTIONS.includes(modelNameOrUrl)) {
+ this.modelName = modelNameOrUrl;
+ this.modelUrl = null;
+ switch (this.modelName) {
+ case 'speechcommands18w':
+ this.modelToUse = speechCommands;
+ break;
+ default:
+ this.modelToUse = null;
+ }
+ } else {
+ // Default to speechCommands for now
+ this.modelToUse = speechCommands;
+ this.modelUrl = modelNameOrUrl;
+ }
+ }
+ // Load the model
+ this.ready = callCallback(this.loadModel(options, this.modelUrl), callback);
+ }
+
+ async loadModel(options, url) {
+ this.model = await this.modelToUse.load(options, url);
+ return this;
+ }
+
+ async classifyInternal(numberOfClasses, callback) {
+ // Wait for the model to be ready
+ await this.ready;
+ await tf.nextFrame();
+
+ return this.model.classify(numberOfClasses, callback);
+ }
+
+ /**
+ * Classifies the audio from microphone and takes a callback to handle the results
+ * @param {function | number} numOrCallback -
+ * takes any of the following params
+ * @param {function} cb - a callback function that handles the results of the function.
+ * @return {function} a promise or the results of a given callback, cb.
+ */
+ async classify(numOrCallback = null, cb) {
+ let numberOfClasses = this.options.topk;
+ let callback;
+
+ if (typeof numOrCallback === 'number') {
+ numberOfClasses = numOrCallback;
+ } else if (typeof numOrCallback === 'function') {
+ callback = numOrCallback;
+ }
+
+ if (typeof cb === 'function') {
+ callback = cb;
+ }
+ return this.classifyInternal(numberOfClasses, callback);
+ }
+}
+
+const soundClassifier = (modelName, optionsOrCallback, cb) => {
+ let model;
+ let options = {};
+ let callback = cb;
+
+ if (typeof modelName === 'string') {
+ model = MODEL_OPTIONS.includes(modelName.toLowerCase()) ? modelName.toLowerCase() : modelName; // keep URLs as-is; they can be case-sensitive
+ } else {
+ throw new Error('Please specify a model to use. E.g: "SpeechCommands18w"');
+ }
+
+ if (typeof optionsOrCallback === 'object') {
+ options = optionsOrCallback;
+ } else if (typeof optionsOrCallback === 'function') {
+ callback = optionsOrCallback;
+ }
+
+ const instance = new SoundClassifier(model, options, callback);
+ return callback ? instance : instance.ready;
+};
+
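+// A minimal usage sketch (classify() starts listening to the microphone; labels depend on the loaded model):
+//   const classifier = ml5.soundClassifier('SpeechCommands18w', () => console.log('model ready'));
+//   classifier.classify((err, results) => console.log(results));
+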
+export default soundClassifier;
diff --git a/src/SoundClassifier/speechcommands.js b/src/SoundClassifier/speechcommands.js
new file mode 100644
index 000000000..996397a87
--- /dev/null
+++ b/src/SoundClassifier/speechcommands.js
@@ -0,0 +1,44 @@
+// Copyright (c) 2018 ml5
+//
+// This software is released under the MIT License.
+// https://opensource.org/licenses/MIT
+
+import * as tfjsSpeechCommands from '@tensorflow-models/speech-commands';
+import { getTopKClassesFromArray } from '../utils/gettopkclasses';
+
+export class SpeechCommands {
+ constructor(options) {
+ this.options = options;
+ }
+
+ async load(url) {
+ if (url) {
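+ // Derive the metadata URL from the model URL; assumes metadata.json sits alongside model.json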
+ const split = url.split("/");
+ const prefix = split.slice(0, split.length - 1).join("/");
+ const metadataJson = `${prefix}/metadata.json`;
+ this.model = tfjsSpeechCommands.create('BROWSER_FFT', undefined, url, metadataJson);
+ } else this.model = tfjsSpeechCommands.create('BROWSER_FFT');
+ await this.model.ensureModelLoaded();
+ this.allLabels = this.model.wordLabels();
+ }
+
+ classify(topk = this.allLabels.length, cb) {
+ return this.model.listen(result => {
+ if (result.scores) {
+ const classes = getTopKClassesFromArray(result.scores, topk, this.allLabels)
+ .map(c => ({ label: c.className, confidence: c.probability }));
+ return cb(null, classes);
+ }
+ return cb(`ERROR: Cannot find scores in result: ${result}`);
+ }, this.options)
+ .catch(err => {
+ return cb(`ERROR: ${err.message}`);
+ });
+ }
+}
+
+export async function load(options, url) {
+ const speechCommandsModel = new SpeechCommands(options);
+ await speechCommandsModel.load(url);
+ return speechCommandsModel;
+}
diff --git a/src/YOLO/index_test.js b/src/YOLO/index_test.js
index 75eedef3f..ac72b7541 100644
--- a/src/YOLO/index_test.js
+++ b/src/YOLO/index_test.js
@@ -18,7 +18,7 @@ describe('YOLO', () => {
async function getRobin() {
const img = new Image();
img.crossOrigin = '';
- img.src = 'https://ml5js.org/docs/assets/img/bird.jpg';
+ img.src = 'https://cdn.jsdelivr.net/gh/ml5js/ml5-library@development/assets/bird.jpg';
await new Promise((resolve) => { img.onload = resolve; });
return img;
}
diff --git a/src/index.js b/src/index.js
index 9cea52fa5..e1c535829 100644
--- a/src/index.js
+++ b/src/index.js
@@ -6,6 +6,7 @@
import * as tf from '@tensorflow/tfjs';
import pitchDetection from './PitchDetection/';
import imageClassifier from './ImageClassifier/';
+import soundClassifier from './SoundClassifier/';
import KNNClassifier from './KNNClassifier/';
import featureExtractor from './FeatureExtractor/';
import word2vec from './Word2vec/';
@@ -22,12 +23,15 @@ import DCGAN from './DCGAN';
import preloadRegister from './utils/p5PreloadHelper';
import { version } from '../package.json';
import sentiment from './Sentiment';
+import bodyPix from './BodyPix';
const withPreload = {
charRNN,
CVAE,
+ DCGAN,
featureExtractor,
imageClassifier,
+ soundClassifier,
pitchDetection,
pix2pix,
poseNet,
@@ -41,8 +45,8 @@ const withPreload = {
module.exports = Object.assign({}, preloadRegister(withPreload), {
KNNClassifier,
...imageUtils,
- DCGAN,
tf,
version,
sentiment,
+ bodyPix,
});
diff --git a/src/utils/DOODLENET_CLASSES.js b/src/utils/DOODLENET_CLASSES.js
new file mode 100644
index 000000000..9f6592fd5
--- /dev/null
+++ b/src/utils/DOODLENET_CLASSES.js
@@ -0,0 +1,353 @@
+// Copyright (c) 2018 ml5
+//
+// This software is released under the MIT License.
+// https://opensource.org/licenses/MIT
+
+/* eslint-disable */
+export default [
+ "flashlight",
+ "belt",
+ "mushroom",
+ "pond",
+ "strawberry",
+ "pineapple",
+ "sun",
+ "cow",
+ "ear",
+ "bush",
+ "pliers",
+ "watermelon",
+ "apple",
+ "baseball",
+ "feather",
+ "shoe",
+ "leaf",
+ "lollipop",
+ "crown",
+ "ocean",
+ "horse",
+ "mountain",
+ "mosquito",
+ "mug",
+ "hospital",
+ "saw",
+ "castle",
+ "angel",
+ "underwear",
+ "traffic_light",
+ "cruise_ship",
+ "marker",
+ "blueberry",
+ "flamingo",
+ "face",
+ "hockey_stick",
+ "bucket",
+ "campfire",
+ "asparagus",
+ "skateboard",
+ "door",
+ "suitcase",
+ "skull",
+ "cloud",
+ "paint_can",
+ "hockey_puck",
+ "steak",
+ "house_plant",
+ "sleeping_bag",
+ "bench",
+ "snowman",
+ "arm",
+ "crayon",
+ "fan",
+ "shovel",
+ "leg",
+ "washing_machine",
+ "harp",
+ "toothbrush",
+ "tree",
+ "bear",
+ "rake",
+ "megaphone",
+ "knee",
+ "guitar",
+ "calculator",
+ "hurricane",
+ "grapes",
+ "paintbrush",
+ "couch",
+ "nose",
+ "square",
+ "wristwatch",
+ "penguin",
+ "bridge",
+ "octagon",
+ "submarine",
+ "screwdriver",
+ "rollerskates",
+ "ladder",
+ "wine_bottle",
+ "cake",
+ "bracelet",
+ "broom",
+ "yoga",
+ "finger",
+ "fish",
+ "line",
+ "truck",
+ "snake",
+ "bus",
+ "stitches",
+ "snorkel",
+ "shorts",
+ "bowtie",
+ "pickup_truck",
+ "tooth",
+ "snail",
+ "foot",
+ "crab",
+ "school_bus",
+ "train",
+ "dresser",
+ "sock",
+ "tractor",
+ "map",
+ "hedgehog",
+ "coffee_cup",
+ "computer",
+ "matches",
+ "beard",
+ "frog",
+ "crocodile",
+ "bathtub",
+ "rain",
+ "moon",
+ "bee",
+ "knife",
+ "boomerang",
+ "lighthouse",
+ "chandelier",
+ "jail",
+ "pool",
+ "stethoscope",
+ "frying_pan",
+ "cell_phone",
+ "binoculars",
+ "purse",
+ "lantern",
+ "birthday_cake",
+ "clarinet",
+ "palm_tree",
+ "aircraft_carrier",
+ "vase",
+ "eraser",
+ "shark",
+ "skyscraper",
+ "bicycle",
+ "sink",
+ "teapot",
+ "circle",
+ "tornado",
+ "bird",
+ "stereo",
+ "mouth",
+ "key",
+ "hot_dog",
+ "spoon",
+ "laptop",
+ "cup",
+ "bottlecap",
+ "The_Great_Wall_of_China",
+ "The_Mona_Lisa",
+ "smiley_face",
+ "waterslide",
+ "eyeglasses",
+ "ceiling_fan",
+ "lobster",
+ "moustache",
+ "carrot",
+ "garden",
+ "police_car",
+ "postcard",
+ "necklace",
+ "helmet",
+ "blackberry",
+ "beach",
+ "golf_club",
+ "car",
+ "panda",
+ "alarm_clock",
+ "t-shirt",
+ "dog",
+ "bread",
+ "wine_glass",
+ "lighter",
+ "flower",
+ "bandage",
+ "drill",
+ "butterfly",
+ "swan",
+ "owl",
+ "raccoon",
+ "squiggle",
+ "calendar",
+ "giraffe",
+ "elephant",
+ "trumpet",
+ "rabbit",
+ "trombone",
+ "sheep",
+ "onion",
+ "church",
+ "flip_flops",
+ "spreadsheet",
+ "pear",
+ "clock",
+ "roller_coaster",
+ "parachute",
+ "kangaroo",
+ "duck",
+ "remote_control",
+ "compass",
+ "monkey",
+ "rainbow",
+ "tennis_racquet",
+ "lion",
+ "pencil",
+ "string_bean",
+ "oven",
+ "star",
+ "cat",
+ "pizza",
+ "soccer_ball",
+ "syringe",
+ "flying_saucer",
+ "eye",
+ "cookie",
+ "floor_lamp",
+ "mouse",
+ "toilet",
+ "toaster",
+ "The_Eiffel_Tower",
+ "airplane",
+ "stove",
+ "cello",
+ "stop_sign",
+ "tent",
+ "diving_board",
+ "light_bulb",
+ "hammer",
+ "scorpion",
+ "headphones",
+ "basket",
+ "spider",
+ "paper_clip",
+ "sweater",
+ "ice_cream",
+ "envelope",
+ "sea_turtle",
+ "donut",
+ "hat",
+ "hourglass",
+ "broccoli",
+ "jacket",
+ "backpack",
+ "book",
+ "lightning",
+ "drums",
+ "snowflake",
+ "radio",
+ "banana",
+ "camel",
+ "canoe",
+ "toothpaste",
+ "chair",
+ "picture_frame",
+ "parrot",
+ "sandwich",
+ "lipstick",
+ "pants",
+ "violin",
+ "brain",
+ "power_outlet",
+ "triangle",
+ "hamburger",
+ "dragon",
+ "bulldozer",
+ "cannon",
+ "dolphin",
+ "zebra",
+ "animal_migration",
+ "camouflage",
+ "scissors",
+ "basketball",
+ "elbow",
+ "umbrella",
+ "windmill",
+ "table",
+ "rifle",
+ "hexagon",
+ "potato",
+ "anvil",
+ "sword",
+ "peanut",
+ "axe",
+ "television",
+ "rhinoceros",
+ "baseball_bat",
+ "speedboat",
+ "sailboat",
+ "zigzag",
+ "garden_hose",
+ "river",
+ "house",
+ "pillow",
+ "ant",
+ "tiger",
+ "stairs",
+ "cooler",
+ "see_saw",
+ "piano",
+ "fireplace",
+ "popsicle",
+ "dumbbell",
+ "mailbox",
+ "barn",
+ "hot_tub",
+ "teddy-bear",
+ "fork",
+ "dishwasher",
+ "peas",
+ "hot_air_balloon",
+ "keyboard",
+ "microwave",
+ "wheel",
+ "fire_hydrant",
+ "van",
+ "camera",
+ "whale",
+ "candle",
+ "octopus",
+ "pig",
+ "swing_set",
+ "helicopter",
+ "saxophone",
+ "passport",
+ "bat",
+ "ambulance",
+ "diamond",
+ "goatee",
+ "fence",
+ "grass",
+ "mermaid",
+ "motorbike",
+ "microphone",
+ "toe",
+ "cactus",
+ "nail",
+ "telephone",
+ "hand",
+ "squirrel",
+ "streetlight",
+ "bed",
+ "firetruck",
+];
diff --git a/src/utils/gettopkclasses.js b/src/utils/gettopkclasses.js
new file mode 100644
index 000000000..86befe098
--- /dev/null
+++ b/src/utils/gettopkclasses.js
@@ -0,0 +1,38 @@
+// Copyright (c) 2018 ml5
+//
+// This software is released under the MIT License.
+// https://opensource.org/licenses/MIT
+
+export function getTopKClassesFromArray(values, topK, CLASSES) {
+ const valuesAndIndices = [];
+ for (let i = 0; i < values.length; i += 1) {
+ valuesAndIndices.push({
+ value: values[i],
+ index: i,
+ });
+ }
+ valuesAndIndices.sort((a, b) => b.value - a.value);
+
+ const topkValues = new Float32Array(topK);
+ const topkIndices = new Int32Array(topK);
+ for (let i = 0; i < topK; i += 1) {
+ topkValues[i] = valuesAndIndices[i].value;
+ topkIndices[i] = valuesAndIndices[i].index;
+ }
+
+ const topClassesAndProbs = [];
+ for (let i = 0; i < topkIndices.length; i += 1) {
+ topClassesAndProbs.push({
+ className: CLASSES[topkIndices[i]],
+ probability: topkValues[i],
+ });
+ }
+ return topClassesAndProbs;
+}
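+
+// Example with hypothetical scores:
+//   getTopKClassesFromArray([0.1, 0.7, 0.2], 2, ['cat', 'dog', 'bird'])
+//   -> [{ className: 'dog', probability: 0.7 }, { className: 'bird', probability: 0.2 }]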
+
+export async function getTopKClassesFromTensor(logits, topK, CLASSES) {
+ const values = await logits.data();
+ return getTopKClassesFromArray(values, topK, CLASSES);
+}
+
+export default { getTopKClassesFromArray, getTopKClassesFromTensor }
diff --git a/webpack.common.babel.js b/webpack.common.babel.js
index 54437d221..bef3e4b08 100644
--- a/webpack.common.babel.js
+++ b/webpack.common.babel.js
@@ -30,5 +30,8 @@ export default {
include,
},
],
+ },
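+ // Stub out Node's fs for the browser build (presumably needed by @tensorflow-models/speech-commands)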
+ node: {
+ fs: "empty"
}
};