diff --git a/src/CVAE/index.js b/src/CVAE/index.js index 4e249c618..13620fd75 100644 --- a/src/CVAE/index.js +++ b/src/CVAE/index.js @@ -10,6 +10,7 @@ */ import * as tf from '@tensorflow/tfjs'; +import axios from "axios"; import callCallback from '../utils/callcallback'; import p5Utils from '../utils/p5Utils'; @@ -28,13 +29,11 @@ class Cvae { this.ready = false; this.model = {}; this.latentDim = tf.randomUniform([1, 16]); - this.modelPath = modelPath; - this.modelPathPrefix = ''; - this.jsonLoader().then(val => { - this.modelPathPrefix = this.modelPath.split('manifest.json')[0] - this.ready = callCallback(this.loadCVAEModel(this.modelPathPrefix+val.model), callback); - this.labels = val.labels; + const [modelPathPrefix] = modelPath.split('manifest.json'); + axios.get(modelPath).then(({ data }) => { + this.ready = callCallback(this.loadCVAEModel(modelPathPrefix + data.model), callback); + this.labels = data.labels; // get an array full of zero with the length of labels [0, 0, 0 ...] this.labelVector = Array(this.labels.length+1).fill(0); }); @@ -114,21 +113,6 @@ class Cvae { return { src, raws, image }; } - async jsonLoader() { - return new Promise((resolve, reject) => { - const xhr = new XMLHttpRequest(); - xhr.open('GET', this.modelPath); - - xhr.onload = () => { - const json = JSON.parse(xhr.responseText); - resolve(json); - }; - xhr.onerror = (error) => { - reject(error); - }; - xhr.send(); - }); - } } const CVAE = (model, callback) => new Cvae(model, callback); diff --git a/src/utils/checkpointLoader.js b/src/utils/checkpointLoader.js index 8830a55a2..6215b3a39 100644 --- a/src/utils/checkpointLoader.js +++ b/src/utils/checkpointLoader.js @@ -4,86 +4,119 @@ // https://opensource.org/licenses/MIT import * as tf from '@tensorflow/tfjs'; +import axios from 'axios'; const MANIFEST_FILE = 'manifest.json'; +/** + * @typedef {Record }>} Manifest + */ +/** + * Loads all of the variables of a model from a directory + * which contains a `manifest.json` file and individual variable data files. + * The `manifest.json` contains the `filename` and `shape` for each data file. + * + * @class + * @property {string} urlPath + * @property {Manifest} [checkpointManifest] + * @property {Record} variables + */ export default class CheckpointLoader { + /** + * @param {string} urlPath - the directory URL + */ constructor(urlPath) { - this.urlPath = urlPath; - if (this.urlPath.charAt(this.urlPath.length - 1) !== '/') { - this.urlPath += '/'; - } + this.urlPath = urlPath.endsWith('/') ? urlPath : `${urlPath}/`; + this.variables = {}; } + /** + * @private + * Executes the request to load the manifest.json file. + * + * @return {Promise} + */ async loadManifest() { - return new Promise((resolve, reject) => { - const xhr = new XMLHttpRequest(); - xhr.open('GET', this.urlPath + MANIFEST_FILE); - - xhr.onload = () => { - this.checkpointManifest = JSON.parse(xhr.responseText); - resolve(); - }; - xhr.onerror = (error) => { - reject(); - throw new Error(`${MANIFEST_FILE} not found at ${this.urlPath}. ${error}`); - }; - xhr.send(); - }); + try { + const response = await axios.get(this.urlPath + MANIFEST_FILE); + return response.data; + } catch (error) { + throw new Error(`${MANIFEST_FILE} not found at ${this.urlPath}. ${error}`); + } } + /** + * @private + * Executes the request to load the file for a variable. + * + * @param {string} varName + * @return {Promise} + */ + async loadVariable(varName) { + const manifest = await this.getCheckpointManifest(); + if (!(varName in manifest)) { + throw new Error(`Cannot load non-existent variable ${varName}`); + } + const { filename, shape } = manifest[varName]; + const url = this.urlPath + filename; + try { + const response = await axios.get(url, { responseType: 'arraybuffer' }); + const values = new Float32Array(response.data); + return tf.tensor(values, shape); + } catch (error) { + throw new Error(`Error loading variable ${varName} from URL ${url}: ${error}`); + } + } + /** + * @public + * Lazy-load the contents of the manifest.json file. + * + * @return {Promise} + */ async getCheckpointManifest() { - if (this.checkpointManifest == null) { - await this.loadManifest(); + if (!this.checkpointManifest) { + this.checkpointManifest = await this.loadManifest(); } return this.checkpointManifest; } + /** + * @public + * Get the property names for each variable in the manifest. + * + * @return {Promise} + */ + async getKeys() { + const manifest = await this.getCheckpointManifest(); + return Object.keys(manifest); + } + + /** + * @public + * Get a dictionary with the tensors for all variables in the manifest. + * + * @return {Promise>} + */ async getAllVariables() { - if (this.variables != null) { - return Promise.resolve(this.variables); - } - await this.getCheckpointManifest(); - const variableNames = Object.keys(this.checkpointManifest); + // Ensure that all keys are loaded and then return the dictionary. + const variableNames = await this.getKeys(); const variablePromises = variableNames.map(v => this.getVariable(v)); - return Promise.all(variablePromises).then((variables) => { - this.variables = {}; - for (let i = 0; i < variables.length; i += 1) { - this.variables[variableNames[i]] = variables[i]; - } - return this.variables; - }); + await Promise.all(variablePromises); + return this.variables; } - getVariable(varName) { - if (!(varName in this.checkpointManifest)) { - throw new Error(`Cannot load non-existent variable ${varName}`); - } - const variableRequestPromiseMethod = (resolve) => { - const xhr = new XMLHttpRequest(); - xhr.responseType = 'arraybuffer'; - const fname = this.checkpointManifest[varName].filename; - xhr.open('GET', this.urlPath + fname); - xhr.onload = () => { - if (xhr.status === 404) { - throw new Error(`Not found variable ${varName}`); - } - const values = new Float32Array(xhr.response); - const tensor = tf.tensor(values, this.checkpointManifest[varName].shape); - resolve(tensor); - }; - xhr.onerror = (error) => { - throw new Error(`Could not fetch variable ${varName}: ${error}`); - }; - xhr.send(); - }; - if (this.checkpointManifest == null) { - return new Promise((resolve) => { - this.loadManifest().then(() => { - new Promise(variableRequestPromiseMethod).then(resolve); - }); - }); + + /** + * @public + * Access a single variable from its key. Will load only if not previously loaded. + * + * @param {string} varName + * @return {Promise} + */ + async getVariable(varName) { + if (!this.variables[varName]) { + this.variables[varName] = await this.loadVariable(varName); } - return new Promise(variableRequestPromiseMethod); + return this.variables[varName]; } } diff --git a/src/utils/checkpointLoaderPix2pix.js b/src/utils/checkpointLoaderPix2pix.js index c6d2d0677..ec66938b3 100644 --- a/src/utils/checkpointLoaderPix2pix.js +++ b/src/utils/checkpointLoaderPix2pix.js @@ -1,68 +1,67 @@ -/* eslint max-len: "off" */ - import * as tf from '@tensorflow/tfjs'; +import axios from 'axios'; +/** + * Pix2Pix loads data from a '.pict' file. + * File contains the properties (name and tensor shape) for each variable + * and a huge array of numbers for all of the variables. + * Numbers must be assigned to the correct variable. + */ export default class CheckpointLoaderPix2pix { + /** + * @param {string} urlPath + */ constructor(urlPath) { + /** + * @type {string} + */ this.urlPath = urlPath; } - getAllVariables() { - return new Promise((resolve, reject) => { - const weightsCache = {}; - if (this.urlPath in weightsCache) { - resolve(weightsCache[this.urlPath]); - return; - } - - const xhr = new XMLHttpRequest(); - xhr.open('GET', this.urlPath, true); - xhr.responseType = 'arraybuffer'; - xhr.onload = () => { - if (xhr.status !== 200) { - reject(new Error('missing model')); - return; - } - const buf = xhr.response; - if (!buf) { - reject(new Error('invalid arraybuffer')); - return; - } + async getAllVariables() { + // Load the file as an ArrayBuffer. + const response = await axios.get(this.urlPath, { responseType: 'arraybuffer' }) + .catch(error => { + throw new Error(`No model found. Failed with error ${error}`); + }); + /** @type {ArrayBuffer} */ + const buf = response.data; - const parts = []; - let offset = 0; - while (offset < buf.byteLength) { - const b = new Uint8Array(buf.slice(offset, offset + 4)); - offset += 4; - const len = (b[0] << 24) + (b[1] << 16) + (b[2] << 8) + b[3]; // eslint-disable-line no-bitwise - parts.push(buf.slice(offset, offset + len)); - offset += len; - } + // Break data into three parts: shapes, index, and encoded. + /** @type {ArrayBuffer[]} */ + const parts = []; + let offset = 0; + while (offset < buf.byteLength) { + const b = new Uint8Array(buf.slice(offset, offset + 4)); + offset += 4; + const len = (b[0] << 24) + (b[1] << 16) + (b[2] << 8) + b[3]; // eslint-disable-line no-bitwise + parts.push(buf.slice(offset, offset + len)); + offset += len; + } - const shapes = JSON.parse((new TextDecoder('utf8')).decode(parts[0])); - const index = new Float32Array(parts[1]); - const encoded = new Uint8Array(parts[2]); + /** @type {Array<{ name: string, shape: number[] }>} */ + const shapes = JSON.parse((new TextDecoder('utf8')).decode(parts[0])); + const index = new Float32Array(parts[1]); + const encoded = new Uint8Array(parts[2]); - // decode using index - const arr = new Float32Array(encoded.length); - for (let i = 0; i < arr.length; i += 1) { - arr[i] = index[encoded[i]]; - } + // Dictionary of variables by name. + /** @type {Record} */ + const weights = {}; - const weights = {}; - offset = 0; - for (let i = 0; i < shapes.length; i += 1) { - const { shape } = shapes[i]; - const size = shape.reduce((total, num) => total * num); - const values = arr.slice(offset, offset + size); - const tfarr = tf.tensor1d(values, 'float32'); - weights[shapes[i].name] = tfarr.reshape(shape); - offset += size; - } - weightsCache[this.urlPath] = weights; - resolve(weights); - }; - xhr.send(null); + // Create a tensor for each shape. + offset = 0; + shapes.forEach(({ shape, name }) => { + const size = shape.reduce((total, num) => total * num); + // Get the raw data. + const raw = encoded.slice(offset, offset + size); + // Decode using index. + const values = new Float32Array(raw.length); + raw.forEach((value, i) => { + values[i] = index[value]; + }); + weights[name] = tf.tensor(values, shape, 'float32'); + offset += size; }); + return weights; } }