Skip to content

Commit

Permalink
Rewrite CheckpointLoaderPix2pix using axios instead of xhr.
Browse files Browse the repository at this point in the history
  • Loading branch information
lindapaiste committed Apr 24, 2022
1 parent 35a003e commit 4e6a427
Showing 1 changed file with 53 additions and 54 deletions.
107 changes: 53 additions & 54 deletions src/utils/checkpointLoaderPix2pix.js
Original file line number Diff line number Diff line change
@@ -1,68 +1,67 @@
/* eslint max-len: "off" */

import * as tf from '@tensorflow/tfjs';
import axios from 'axios';

/**
* Pix2Pix loads data from a '.pict' file.
* File contains the properties (name and tensor shape) for each variable
* and a huge array of numbers for all of the variables.
* Numbers must be assigned to the correct variable.
*/
export default class CheckpointLoaderPix2pix {
/**
* @param {string} urlPath
*/
constructor(urlPath) {
/**
* @type {string}
*/
this.urlPath = urlPath;
}

getAllVariables() {
return new Promise((resolve, reject) => {
const weightsCache = {};
if (this.urlPath in weightsCache) {
resolve(weightsCache[this.urlPath]);
return;
}

const xhr = new XMLHttpRequest();
xhr.open('GET', this.urlPath, true);
xhr.responseType = 'arraybuffer';
xhr.onload = () => {
if (xhr.status !== 200) {
reject(new Error('missing model'));
return;
}
const buf = xhr.response;
if (!buf) {
reject(new Error('invalid arraybuffer'));
return;
}
async getAllVariables() {
// Load the file as an ArrayBuffer.
const response = await axios.get(this.urlPath, { responseType: 'arraybuffer' })
.catch(error => {
throw new Error(`No model found. Failed with error ${error}`);
});
/** @type {ArrayBuffer} */
const buf = response.data;

const parts = [];
let offset = 0;
while (offset < buf.byteLength) {
const b = new Uint8Array(buf.slice(offset, offset + 4));
offset += 4;
const len = (b[0] << 24) + (b[1] << 16) + (b[2] << 8) + b[3]; // eslint-disable-line no-bitwise
parts.push(buf.slice(offset, offset + len));
offset += len;
}
// Break data into three parts: shapes, index, and encoded.
/** @type {ArrayBuffer[]} */
const parts = [];
let offset = 0;
while (offset < buf.byteLength) {
const b = new Uint8Array(buf.slice(offset, offset + 4));
offset += 4;
const len = (b[0] << 24) + (b[1] << 16) + (b[2] << 8) + b[3]; // eslint-disable-line no-bitwise
parts.push(buf.slice(offset, offset + len));
offset += len;
}

const shapes = JSON.parse((new TextDecoder('utf8')).decode(parts[0]));
const index = new Float32Array(parts[1]);
const encoded = new Uint8Array(parts[2]);
/** @type {Array<{ name: string, shape: number[] }>} */
const shapes = JSON.parse((new TextDecoder('utf8')).decode(parts[0]));
const index = new Float32Array(parts[1]);
const encoded = new Uint8Array(parts[2]);

// decode using index
const arr = new Float32Array(encoded.length);
for (let i = 0; i < arr.length; i += 1) {
arr[i] = index[encoded[i]];
}
// Dictionary of variables by name.
/** @type {Record<string, tf.Tensor>} */
const weights = {};

const weights = {};
offset = 0;
for (let i = 0; i < shapes.length; i += 1) {
const { shape } = shapes[i];
const size = shape.reduce((total, num) => total * num);
const values = arr.slice(offset, offset + size);
const tfarr = tf.tensor1d(values, 'float32');
weights[shapes[i].name] = tfarr.reshape(shape);
offset += size;
}
weightsCache[this.urlPath] = weights;
resolve(weights);
};
xhr.send(null);
// Create a tensor for each shape.
offset = 0;
shapes.forEach(({ shape, name }) => {
const size = shape.reduce((total, num) => total * num);
// Get the raw data.
const raw = encoded.slice(offset, offset + size);
// Decode using index.
const values = new Float32Array(raw.length);
raw.forEach((value, i) => {
values[i] = index[value];
});
weights[name] = tf.tensor(values, shape, 'float32');
offset += size;
});
return weights;
}
}

0 comments on commit 4e6a427

Please sign in to comment.