Skip to content

Commit

Permalink
Merge pull request #1361 from lindapaiste/fix/xhr-to-axios
Browse files Browse the repository at this point in the history
Rewrite XHR code using Axios
  • Loading branch information
joeyklee authored Apr 25, 2022
2 parents 3c1f132 + 4e6a427 commit a9c5098
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 136 deletions.
26 changes: 5 additions & 21 deletions src/CVAE/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
*/

import * as tf from '@tensorflow/tfjs';
import axios from "axios";
import callCallback from '../utils/callcallback';
import p5Utils from '../utils/p5Utils';

Expand All @@ -28,13 +29,11 @@ class Cvae {
this.ready = false;
this.model = {};
this.latentDim = tf.randomUniform([1, 16]);
this.modelPath = modelPath;
this.modelPathPrefix = '';

this.jsonLoader().then(val => {
this.modelPathPrefix = this.modelPath.split('manifest.json')[0]
this.ready = callCallback(this.loadCVAEModel(this.modelPathPrefix+val.model), callback);
this.labels = val.labels;
const [modelPathPrefix] = modelPath.split('manifest.json');
axios.get(modelPath).then(({ data }) => {
this.ready = callCallback(this.loadCVAEModel(modelPathPrefix + data.model), callback);
this.labels = data.labels;
// get an array full of zero with the length of labels [0, 0, 0 ...]
this.labelVector = Array(this.labels.length+1).fill(0);
});
Expand Down Expand Up @@ -114,21 +113,6 @@ class Cvae {
return { src, raws, image };
}

async jsonLoader() {
return new Promise((resolve, reject) => {
const xhr = new XMLHttpRequest();
xhr.open('GET', this.modelPath);

xhr.onload = () => {
const json = JSON.parse(xhr.responseText);
resolve(json);
};
xhr.onerror = (error) => {
reject(error);
};
xhr.send();
});
}
}

const CVAE = (model, callback) => new Cvae(model, callback);
Expand Down
155 changes: 94 additions & 61 deletions src/utils/checkpointLoader.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,86 +4,119 @@
// https://opensource.org/licenses/MIT

import * as tf from '@tensorflow/tfjs';
import axios from 'axios';

const MANIFEST_FILE = 'manifest.json';

/**
* @typedef {Record<string, { filename: string, shape: Array<number> }>} Manifest
*/
/**
* Loads all of the variables of a model from a directory
* which contains a `manifest.json` file and individual variable data files.
* The `manifest.json` contains the `filename` and `shape` for each data file.
*
* @class
* @property {string} urlPath
* @property {Manifest} [checkpointManifest]
* @property {Record<string, tf.Tensor>} variables
*/
export default class CheckpointLoader {
/**
* @param {string} urlPath - the directory URL
*/
constructor(urlPath) {
this.urlPath = urlPath;
if (this.urlPath.charAt(this.urlPath.length - 1) !== '/') {
this.urlPath += '/';
}
this.urlPath = urlPath.endsWith('/') ? urlPath : `${urlPath}/`;
this.variables = {};
}

/**
* @private
* Executes the request to load the manifest.json file.
*
* @return {Promise<Manifest>}
*/
async loadManifest() {
return new Promise((resolve, reject) => {
const xhr = new XMLHttpRequest();
xhr.open('GET', this.urlPath + MANIFEST_FILE);

xhr.onload = () => {
this.checkpointManifest = JSON.parse(xhr.responseText);
resolve();
};
xhr.onerror = (error) => {
reject();
throw new Error(`${MANIFEST_FILE} not found at ${this.urlPath}. ${error}`);
};
xhr.send();
});
try {
const response = await axios.get(this.urlPath + MANIFEST_FILE);
return response.data;
} catch (error) {
throw new Error(`${MANIFEST_FILE} not found at ${this.urlPath}. ${error}`);
}
}

/**
* @private
* Executes the request to load the file for a variable.
*
* @param {string} varName
* @return {Promise<tf.Tensor>}
*/
async loadVariable(varName) {
const manifest = await this.getCheckpointManifest();
if (!(varName in manifest)) {
throw new Error(`Cannot load non-existent variable ${varName}`);
}
const { filename, shape } = manifest[varName];
const url = this.urlPath + filename;
try {
const response = await axios.get(url, { responseType: 'arraybuffer' });
const values = new Float32Array(response.data);
return tf.tensor(values, shape);
} catch (error) {
throw new Error(`Error loading variable ${varName} from URL ${url}: ${error}`);
}
}

/**
* @public
* Lazy-load the contents of the manifest.json file.
*
* @return {Promise<Manifest>}
*/
async getCheckpointManifest() {
if (this.checkpointManifest == null) {
await this.loadManifest();
if (!this.checkpointManifest) {
this.checkpointManifest = await this.loadManifest();
}
return this.checkpointManifest;
}

/**
* @public
* Get the property names for each variable in the manifest.
*
* @return {Promise<string[]>}
*/
async getKeys() {
const manifest = await this.getCheckpointManifest();
return Object.keys(manifest);
}

/**
* @public
* Get a dictionary with the tensors for all variables in the manifest.
*
* @return {Promise<Record<string, tf.Tensor>>}
*/
async getAllVariables() {
if (this.variables != null) {
return Promise.resolve(this.variables);
}
await this.getCheckpointManifest();
const variableNames = Object.keys(this.checkpointManifest);
// Ensure that all keys are loaded and then return the dictionary.
const variableNames = await this.getKeys();
const variablePromises = variableNames.map(v => this.getVariable(v));
return Promise.all(variablePromises).then((variables) => {
this.variables = {};
for (let i = 0; i < variables.length; i += 1) {
this.variables[variableNames[i]] = variables[i];
}
return this.variables;
});
await Promise.all(variablePromises);
return this.variables;
}
getVariable(varName) {
if (!(varName in this.checkpointManifest)) {
throw new Error(`Cannot load non-existent variable ${varName}`);
}
const variableRequestPromiseMethod = (resolve) => {
const xhr = new XMLHttpRequest();
xhr.responseType = 'arraybuffer';
const fname = this.checkpointManifest[varName].filename;
xhr.open('GET', this.urlPath + fname);
xhr.onload = () => {
if (xhr.status === 404) {
throw new Error(`Not found variable ${varName}`);
}
const values = new Float32Array(xhr.response);
const tensor = tf.tensor(values, this.checkpointManifest[varName].shape);
resolve(tensor);
};
xhr.onerror = (error) => {
throw new Error(`Could not fetch variable ${varName}: ${error}`);
};
xhr.send();
};
if (this.checkpointManifest == null) {
return new Promise((resolve) => {
this.loadManifest().then(() => {
new Promise(variableRequestPromiseMethod).then(resolve);
});
});

/**
* @public
* Access a single variable from its key. Will load only if not previously loaded.
*
* @param {string} varName
* @return {Promise<tf.Tensor>}
*/
async getVariable(varName) {
if (!this.variables[varName]) {
this.variables[varName] = await this.loadVariable(varName);
}
return new Promise(variableRequestPromiseMethod);
return this.variables[varName];
}
}
107 changes: 53 additions & 54 deletions src/utils/checkpointLoaderPix2pix.js
Original file line number Diff line number Diff line change
@@ -1,68 +1,67 @@
/* eslint max-len: "off" */

import * as tf from '@tensorflow/tfjs';
import axios from 'axios';

/**
* Pix2Pix loads data from a '.pict' file.
* File contains the properties (name and tensor shape) for each variable
* and a huge array of numbers for all of the variables.
* Numbers must be assigned to the correct variable.
*/
export default class CheckpointLoaderPix2pix {
/**
* @param {string} urlPath
*/
constructor(urlPath) {
/**
* @type {string}
*/
this.urlPath = urlPath;
}

getAllVariables() {
return new Promise((resolve, reject) => {
const weightsCache = {};
if (this.urlPath in weightsCache) {
resolve(weightsCache[this.urlPath]);
return;
}

const xhr = new XMLHttpRequest();
xhr.open('GET', this.urlPath, true);
xhr.responseType = 'arraybuffer';
xhr.onload = () => {
if (xhr.status !== 200) {
reject(new Error('missing model'));
return;
}
const buf = xhr.response;
if (!buf) {
reject(new Error('invalid arraybuffer'));
return;
}
async getAllVariables() {
// Load the file as an ArrayBuffer.
const response = await axios.get(this.urlPath, { responseType: 'arraybuffer' })
.catch(error => {
throw new Error(`No model found. Failed with error ${error}`);
});
/** @type {ArrayBuffer} */
const buf = response.data;

const parts = [];
let offset = 0;
while (offset < buf.byteLength) {
const b = new Uint8Array(buf.slice(offset, offset + 4));
offset += 4;
const len = (b[0] << 24) + (b[1] << 16) + (b[2] << 8) + b[3]; // eslint-disable-line no-bitwise
parts.push(buf.slice(offset, offset + len));
offset += len;
}
// Break data into three parts: shapes, index, and encoded.
/** @type {ArrayBuffer[]} */
const parts = [];
let offset = 0;
while (offset < buf.byteLength) {
const b = new Uint8Array(buf.slice(offset, offset + 4));
offset += 4;
const len = (b[0] << 24) + (b[1] << 16) + (b[2] << 8) + b[3]; // eslint-disable-line no-bitwise
parts.push(buf.slice(offset, offset + len));
offset += len;
}

const shapes = JSON.parse((new TextDecoder('utf8')).decode(parts[0]));
const index = new Float32Array(parts[1]);
const encoded = new Uint8Array(parts[2]);
/** @type {Array<{ name: string, shape: number[] }>} */
const shapes = JSON.parse((new TextDecoder('utf8')).decode(parts[0]));
const index = new Float32Array(parts[1]);
const encoded = new Uint8Array(parts[2]);

// decode using index
const arr = new Float32Array(encoded.length);
for (let i = 0; i < arr.length; i += 1) {
arr[i] = index[encoded[i]];
}
// Dictionary of variables by name.
/** @type {Record<string, tf.Tensor>} */
const weights = {};

const weights = {};
offset = 0;
for (let i = 0; i < shapes.length; i += 1) {
const { shape } = shapes[i];
const size = shape.reduce((total, num) => total * num);
const values = arr.slice(offset, offset + size);
const tfarr = tf.tensor1d(values, 'float32');
weights[shapes[i].name] = tfarr.reshape(shape);
offset += size;
}
weightsCache[this.urlPath] = weights;
resolve(weights);
};
xhr.send(null);
// Create a tensor for each shape.
offset = 0;
shapes.forEach(({ shape, name }) => {
const size = shape.reduce((total, num) => total * num);
// Get the raw data.
const raw = encoded.slice(offset, offset + size);
// Decode using index.
const values = new Float32Array(raw.length);
raw.forEach((value, i) => {
values[i] = index[value];
});
weights[name] = tf.tensor(values, shape, 'float32');
offset += size;
});
return weights;
}
}

0 comments on commit a9c5098

Please sign in to comment.