Skip to content

Commit

Permalink
Remove attribute type
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed May 20, 2024
1 parent b91e72f commit 5028ef6
Show file tree
Hide file tree
Showing 42 changed files with 1,687 additions and 3,001 deletions.
37 changes: 16 additions & 21 deletions source/armnn.js
Original file line number Diff line number Diff line change
Expand Up @@ -158,26 +158,32 @@ armnn.Node = class {
this.name = base.layerName;
const inputs = [...base.inputSlots];
while (inputs.length > 0) {
const inputSchema = inputSchemas.length > 0 ? inputSchemas.shift() : { name: '?' };
const count = inputSchema.list ? inputs.length : 1;
const argument = new armnn.Argument(inputSchema.name, inputs.splice(0, count).map((inputSlot) => {
const schema = inputSchemas.length > 0 ? inputSchemas.shift() : { name: '?' };
const count = schema.list ? inputs.length : 1;
const argument = new armnn.Argument(schema.name, inputs.splice(0, count).map((inputSlot) => {
return value(inputSlot.connection.sourceLayerIndex, inputSlot.connection.outputSlotIndex);
}));
this.inputs.push(argument);
}
const outputs = [...base.outputSlots];
while (outputs.length > 0) {
const outputSchema = outputSchemas.length > 0 ? outputSchemas.shift() : { name: '?' };
const count = outputSchema.list ? outputs.length : 1;
this.outputs.push(new armnn.Argument(outputSchema.name, outputs.splice(0, count).map((outputSlot) => {
const schema = outputSchemas.length > 0 ? outputSchemas.shift() : { name: '?' };
const count = schema.list ? outputs.length : 1;
this.outputs.push(new armnn.Argument(schema.name, outputs.splice(0, count).map((outputSlot) => {
return value(base.index, outputSlot.index);
})));
}
}
if (layer.layer) {
if (layer.layer.descriptor && this.type.attributes) {
for (const [name, value] of Object.entries(layer.layer.descriptor)) {
const attribute = new armnn.Attribute(metadata.attribute(type, name), name, value);
for (const [key, obj] of Object.entries(layer.layer.descriptor)) {
const schema = metadata.attribute(name, key);
const type = schema ? schema.type : null;
let value = ArrayBuffer.isView(obj) ? Array.from(obj) : obj;
if (armnn.schema[type]) {
value = armnn.Utility.enum(type, value);
}
const attribute = new armnn.Argument(key, value, type);
this.attributes.push(attribute);
}
}
Expand All @@ -198,23 +204,12 @@ armnn.Node = class {
}
};

armnn.Attribute = class {

constructor(metadata, name, value) {
this.name = name;
this.type = metadata ? metadata.type : null;
this.value = ArrayBuffer.isView(value) ? Array.from(value) : value;
if (armnn.schema[this.type]) {
this.value = armnn.Utility.enum(this.type, this.value);
}
}
};

armnn.Argument = class {

constructor(name, value) {
constructor(name, value, type) {
this.name = name;
this.value = value;
this.type = type || null;
}
};

Expand Down
49 changes: 19 additions & 30 deletions source/barracuda.js
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,10 @@ barracuda.Graph = class {

barracuda.Argument = class {

constructor(name, value) {
constructor(name, value, type) {
this.name = name;
this.value = value;
this.type = type || null;
}
};

Expand Down Expand Up @@ -134,37 +135,25 @@ barracuda.Node = class {
const node = new barracuda.Node(metadata, {}, { name: type, category: 'Activation' }, values);
this.chain = [node];
}
const attribute = (name, type, value, defaultValue) => {
if (value === undefined) {
return;
}
if (Array.isArray(defaultValue) && Array.isArray(value) && value.length === defaultValue.length && value.every((v, i) => v === defaultValue[i])) {
return;
}
if (typeof defaultValue === 'function' && defaultValue(value)) {
return;
const attributes = [
['strides', 'int32[]', []],
['pads', 'int32[]', (value) => Array.isArray(value) && (value.every((v) => v === 0) || value.every((v) => v === -1))],
['pool_size', 'int32[]', []],
['alpha', 'float32', 1],
['beta', 'float32', 0],
['axis', 'int32', -1]
];
for (const [name, type, defaultValue] of attributes) {
const value = layer[name];
if ((value === undefined) ||
(Array.isArray(defaultValue) && Array.isArray(value) && value.length === defaultValue.length && value.every((v, i) => v === defaultValue[i])) ||
(typeof defaultValue === 'function' && defaultValue(value)) ||
(defaultValue === value)) {
continue;
}
if (defaultValue === value) {
return;
}
const attribute = new barracuda.Attribute(name, type, value);
const attribute = new barracuda.Argument(name, value, type);
this.attributes.push(attribute);
};
attribute('strides', 'int32[]', layer.strides, []);
attribute('pads', 'int32[]', layer.pads, (value) => Array.isArray(value) && (value.every((v) => v === 0) || value.every((v) => v === -1)));
attribute('size', 'int32[]', layer.pool_size, []);
attribute('alpha', 'float32', layer.alpha, 1);
attribute('beta', 'float32', layer.beta, 0);
attribute('axis', 'int32', layer.axis, -1);
}
};

barracuda.Attribute = class {

constructor(name, type, value) {
this.name = name;
this.type = type;
this.value = value;
}
}
};

Expand Down
169 changes: 83 additions & 86 deletions source/bigdl.js
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,10 @@ bigdl.Graph = class {

bigdl.Argument = class {

constructor(name, value) {
constructor(name, value, type) {
this.name = name;
this.value = value;
this.type = type || null;
}
};

Expand Down Expand Up @@ -132,107 +133,103 @@ bigdl.Node = class {
]));
}
}
for (const [key, value] of Object.entries(module.attr)) {
for (const [key, obj] of Object.entries(module.attr)) {
if (key === 'module_numerics' || key === 'module_tags') {
continue;
}
if (value.dataType === bigdl.proto.DataType.TENSOR) {
if (value.value) {
this.inputs.push(new bigdl.Argument(key, [new bigdl.Value('', null, new bigdl.Tensor(value.tensorValue, tensors))]));
if (obj.dataType === bigdl.proto.DataType.TENSOR) {
if (obj.value) {
this.inputs.push(new bigdl.Argument(key, [new bigdl.Value('', null, new bigdl.Tensor(obj.tensorValue, tensors))]));
}
continue;
}
if (value.dataType === bigdl.proto.DataType.REGULARIZER && value.value === undefined) {
if (obj.dataType === bigdl.proto.DataType.REGULARIZER && obj.value === undefined) {
continue;
}
if (value.dataType === bigdl.proto.DataType.ARRAY_VALUE && value.arrayValue.datatype === bigdl.proto.DataType.TENSOR) {
this.inputs.push(new bigdl.Argument(key, value.arrayValue.tensor.map((tensor) => new bigdl.Value('', null, new bigdl.Tensor(tensor, tensors)))));
if (obj.dataType === bigdl.proto.DataType.ARRAY_VALUE && obj.arrayValue.datatype === bigdl.proto.DataType.TENSOR) {
this.inputs.push(new bigdl.Argument(key, obj.arrayValue.tensor.map((tensor) => new bigdl.Value('', null, new bigdl.Tensor(tensor, tensors)))));
continue;
}
this.attributes.push(new bigdl.Attribute(key, value));
}
const output = this.name || this.type + module.namePostfix;
this.outputs.push(new bigdl.Argument('output', [values.map(output)]));
}
};

bigdl.Attribute = class {

constructor(name, value) {
this.name = name;
switch (value.dataType) {
case bigdl.proto.DataType.INT32: {
this.type = 'int32';
this.value = value.int32Value;
break;
}
case bigdl.proto.DataType.FLOAT: {
this.type = 'float32';
this.value = value.floatValue;
break;
}
case bigdl.proto.DataType.DOUBLE: {
this.type = 'float64';
this.value = value.doubleValue;
break;
}
case bigdl.proto.DataType.BOOL: {
this.type = 'boolean';
this.value = value.boolValue;
break;
}
case bigdl.proto.DataType.REGULARIZER: {
this.value = value.value;
break;
}
case bigdl.proto.DataType.MODULE: {
this.value = value.bigDLModule;
break;
}
case bigdl.proto.DataType.NAME_ATTR_LIST: {
this.value = value.nameAttrListValue;
break;
}
case bigdl.proto.DataType.ARRAY_VALUE: {
switch (value.arrayValue.datatype) {
case bigdl.proto.DataType.INT32: {
this.type = 'int32[]';
this.value = value.arrayValue.i32;
break;
}
case bigdl.proto.DataType.FLOAT: {
this.type = 'float32[]';
this.value = value.arrayValue.flt;
break;
}
case bigdl.proto.DataType.STRING: {
this.type = 'string[]';
this.value = value.arrayValue.str;
break;
}
case bigdl.proto.DataType.TENSOR: {
this.type = 'tensor[]';
this.value = value.arrayValue.tensor;
break;
let type = null;
let value = null;
switch (obj.dataType) {
case bigdl.proto.DataType.INT32: {
type = 'int32';
value = obj.int32Value;
break;
}
case bigdl.proto.DataType.FLOAT: {
type = 'float32';
value = obj.floatValue;
break;
}
case bigdl.proto.DataType.DOUBLE: {
type = 'float64';
value = obj.doubleValue;
break;
}
case bigdl.proto.DataType.BOOL: {
type = 'boolean';
value = obj.boolValue;
break;
}
case bigdl.proto.DataType.REGULARIZER: {
value = obj.value;
break;
}
case bigdl.proto.DataType.MODULE: {
value = obj.bigDLModule;
break;
}
case bigdl.proto.DataType.NAME_ATTR_LIST: {
value = value.nameAttrListValue;
break;
}
case bigdl.proto.DataType.ARRAY_VALUE: {
switch (obj.arrayValue.datatype) {
case bigdl.proto.DataType.INT32: {
type = 'int32[]';
value = obj.arrayValue.i32;
break;
}
case bigdl.proto.DataType.FLOAT: {
type = 'float32[]';
value = obj.arrayValue.flt;
break;
}
case bigdl.proto.DataType.STRING: {
type = 'string[]';
value = obj.arrayValue.str;
break;
}
case bigdl.proto.DataType.TENSOR: {
type = 'tensor[]';
value = obj.arrayValue.tensor;
break;
}
default: {
throw new bigdl.Error(`Unsupported attribute array data type '${obj.arrayValue.datatype}'.`);
}
}
default: {
throw new bigdl.Error(`Unsupported attribute array data type '${value.arrayValue.datatype}'.`);
break;
}
case bigdl.proto.DataType.DATA_FORMAT: {
switch (obj.dataFormatValue) {
case 0: value = 'NCHW'; break;
case 1: value = 'NHWC'; break;
default: throw new bigdl.Error(`Unsupported data format '${obj.dataFormatValue}'.`);
}
break;
}
break;
}
case bigdl.proto.DataType.DATA_FORMAT: {
switch (value.dataFormatValue) {
case 0: this.value = 'NCHW'; break;
case 1: this.value = 'NHWC'; break;
default: throw new bigdl.Error(`Unsupported data format '${value.dataFormatValue}'.`);
default: {
throw new bigdl.Error(`Unsupported attribute data type '${obj.dataType}'.`);
}
break;
}
default: {
throw new bigdl.Error(`Unsupported attribute data type '${value.dataType}'.`);
}
const argument = new bigdl.Argument(key, value, type);
this.attributes.push(argument);
}
const output = this.name || this.type + module.namePostfix;
this.outputs.push(new bigdl.Argument('output', [values.map(output)]));
}
};

Expand Down
Loading

0 comments on commit 5028ef6

Please sign in to comment.