Bool input tensor #771

Open

bywmm opened this issue Apr 19, 2023 · 7 comments
bywmm commented Apr 19, 2023

Hi all,
I'm using Larq to benchmark the real performance of my BNN. It's very convenient, but I'm running into some trouble. I'll use a toy model as an example.

import tensorflow as tf
import larq as lq
from tensorflow.keras import Input, Model

X_in = Input(shape=(1, 1, 1024), batch_size=1024)
X_in_quantized = lq.quantizers.SteSign()(X_in)

X = lq.layers.QuantConv2D(64, kernel_size=(1, 1),
                          kernel_quantizer="ste_sign",
                          kernel_constraint="weight_clip",
                          )(X_in_quantized)
out = tf.keras.layers.Reshape((-1,))(X)

toy_model = Model(inputs=X_in, outputs=out)

This model has only a single QuantConv2D layer (a 1×1 convolution, equivalent to a QuantDense layer). I then benchmark it on a Raspberry Pi 4B (64-bit OS).
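To produce the toy1.tflite file used below, I convert the Keras model with the LCE converter, roughly like this (a minimal sketch; the file name and the write step are just how I save the result):

import larq_compute_engine as lce

# Convert the Keras model with the LCE MLIR converter and write the
# flatbuffer out so it can be benchmarked on the Raspberry Pi.
tflite_bytes = lce.convert_keras_model(toy_model)
with open("toy1.tflite", "wb") as f:
    f.write(tflite_bytes)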

STARTING!
Log parameter values verbosely: [0]
Min num runs: [50]
Num threads: [1]
Graph: [toy1.tflite]
#threads used for CPU inference: [1]
Loaded model toy1.tflite
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
The input model file size (MB): 0.010452
Initialized session in 3.889ms.
Running benchmark for at least 1 iterations and at least 0.5 seconds but terminate if exceeding 150 seconds.
count=121 first=4357 curr=2186 min=2143 max=4357 avg=2277.17 std=302

Running benchmark for at least 50 iterations and at least 1 seconds but terminate if exceeding 150 seconds.
count=253 first=2206 curr=2179 min=2141 max=2248 avg=2168.41 std=18

Inference timings in us: Init: 3889, First inference: 4357, Warmup (avg): 2277.17, Inference (avg): 2168.41
Note: as the benchmark tool itself affects memory footprint, the following is only APPROXIMATE to the actual memory footprint of the model at runtime. Take the information at your discretion.
Memory footprint delta from the start of the tool (MB): init=0.0507812 overall=10.5469

I think the 10.5469 MB is the memory used at runtime, which includes the floating-point input tensor. I am wondering how to use a quantized (binarized) tensor as the input to save memory; it is effectively a bool input tensor. How can I measure the memory consumption with a bool input tensor?

# Something like
X_in = Input(shape=(1, 1, 1024), batch_size=1024, dtype=tf.bool)
# instead of
X_in = Input(shape=(1, 1, 1024), batch_size=1024)
X_in_quantized = lq.quantizers.SteSign()(X_in)

If you can give me some hints, I would be very grateful.

Tombana (Collaborator) commented Apr 19, 2023

I think the MLIR converter cannot directly create .tflite files with boolean input or output tensors; only float32 and int8 are supported. Even if booleans were supported, every bool would still take up 1 byte rather than 1 bit.

The LCE converter does have a utility function for getting bitpacked boolean output tensors: first create a regular model that ends with lq.quantizers.SteSign() as its output. Then convert it as normal; you will get a float32 output tensor (you can verify that in Netron). Then you can call the following utility function:

from larq_compute_engine.mlir.python.util import strip_lcedequantize_ops

strip_lcedequantize_ops(tflite_file_bytes)

That should result in a tflite file that has an int32 output tensor (again, it's a good idea to verify it in Netron). It does not actually use 32-bit integers, though: these numbers represent bitpacked booleans, where every integer contains 32 booleans.
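A minimal end-to-end sketch of that flow, assuming a Keras model whose final layer is lq.quantizers.SteSign() (the file name is a placeholder):

import larq_compute_engine as lce
from larq_compute_engine.mlir.python.util import strip_lcedequantize_ops

# Convert as usual: because the model ends with SteSign(), the converted
# graph ends in a dequantize step that produces a float32 output tensor.
tflite_bytes = lce.convert_keras_model(model)

# Strip the trailing LceDequantize ops so the output stays bitpacked
# (an int32 tensor with 32 booleans per integer).
stripped_bytes = strip_lcedequantize_ops(tflite_bytes)

with open("model_bitpacked_output.tflite", "wb") as f:
    f.write(stripped_bytes)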

bywmm (Author) commented Apr 19, 2023

Thanks for your reply. It helps me a lot.

Besides, how can I use an int8 input in the LCE benchmark? When I directly modify the dtype of the input tensor, it raises a TypeError.

X_in = Input(shape=(1, 1, 1024,), batch_size=1024, dtype=tf.int8)
TypeError: Value passed to parameter 'x' has DataType int8 not in list of allowed values: bfloat16, float16, float32, float64, int32, int64, complex64, complex128

Tombana (Collaborator) commented Apr 19, 2023

For int8 tflite files you have to do either int8 quantization-aware training or post-training quantization. The tensor in TensorFlow/Keras stays float32, and it becomes int8 during conversion.
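For illustration, a minimal post-training sketch with the LCE converter; the inference_input_type/inference_output_type arguments and experimental_default_int8_range are my recollection of the LCE Python API, so please verify them against the current docs:

import tensorflow as tf
import larq_compute_engine as lce

# The Keras model stays float32; only the converted .tflite file gets
# int8 input and output tensors.
tflite_int8 = lce.convert_keras_model(
    toy_model,
    inference_input_type=tf.int8,
    inference_output_type=tf.int8,
    # Dummy calibration range: fine for benchmarking speed and memory,
    # not for accuracy measurements.
    experimental_default_int8_range=(-3.0, 3.0),
)

with open("toy1_int8.tflite", "wb") as f:
    f.write(tflite_int8)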

bywmm (Author) commented Apr 19, 2023

I get it. Thank you!

bywmm closed this as completed Apr 19, 2023
bywmm (Author) commented Apr 26, 2023

Hi @Tombana, sorry for reopening this issue. After learning more about Larq Compute Engine, I have two new questions.
As discussed above,

X_in = Input(shape=(1, 1, 1024), batch_size=1024, dtype=tf.float32)
X_in_quantized = lq.quantizers.SteSign()(X_in)

I want to use X_in_quantized as the input.
Q1. When I print X_in_quantized in Python, I find that it is still a float tensor with shape [1, 1, 1024]. But after converting the TF model into a TFLite model, it becomes a bitpacked tensor. Can you give me an overview of how this is done?

Q2. Is it possible to add a custom operation, e.g. custom_converter(), that converts the `tf.float32` input tensor into a `tf.int32` bitpacked input tensor?

X_in = Input(shape=(1, 1, 32), batch_size=1024, dtype=tf.float32)
X_in_quantized = custom_converter()(X_in)

If the answer is yes, how can I add such a custom operation and build it?

I'm not familiar with TFLite and MLIR, so I'd appreciate it if you could go into more detail.

Thank you!

bywmm reopened this Apr 26, 2023
Tombana (Collaborator) commented Apr 26, 2023

Q1. When I print X_in_quantized in Python, I find that it is still a float tensor with shape [1, 1, 1024]. But after converting the TF model into a TFLite model, it becomes a bitpacked tensor. Can you give me an overview of how this is done?

This is correct. During training everything is float so that you can compute gradients. Once you convert the model into a TFLite model for inference, it becomes a bitpacked tensor. What exactly is your question about this? The conversion between the two representations happens in the MLIR converter. In the TFLite model there is an operator, LceQuantize, which takes a float or int8 tensor as input and outputs the bitpacked tensor. At inference time it does that by checking (x[i] < 0) for each input value and setting the i-th bit to 0 or 1 accordingly.
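To make the bitpacking concrete, here is an illustrative Python sketch of what LceQuantize computes (this is not LCE's actual C++ kernel; the real implementation may use a different bit or word ordering):

import numpy as np

def bitpack(x):
    # Pack 32 floats per int32: bit i is set when x[i] < 0, i.e. for the
    # values that SteSign would map to -1.
    x = np.asarray(x, dtype=np.float32).ravel()
    assert x.size % 32 == 0, "channel count must be a multiple of 32"
    words = []
    for block in x.reshape(-1, 32):
        word = 0
        for i, value in enumerate(block):
            if value < 0:
                word |= 1 << i
        words.append(word)
    # TFLite reports the tensor as int32, so reinterpret the packed bits.
    return np.array(words, dtype=np.uint32).view(np.int32)

print(bitpack([-1.0, 0.5] * 16))  # 32 inputs -> one int32 word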

Q2. Is it possible to add a custom operation, e.g. custom_converter(), that converts the `tf.float32` input tensor into a `tf.int32` bitpacked input tensor?

X_in = Input(shape=(1, 1, 32), batch_size=1024, dtype=tf.float32)
X_in_quantized = custom_converter()(X_in)

If the answer is yes, how can I add such a custom operation and build it?

If you want the tf.int32 type during training, I'm not sure that's possible. I also don't understand what the purpose of that would be: the TF ops don't understand this bitpacked format, so it would not be very useful.
If you want the tf.int32 type during inference in TFLite and you want to remove the int8 or float32 tensor that comes before LceQuantize, then you can maybe do that by extending the function strip_lcedequantize_ops from larq_compute_engine.mlir.python.util.

bywmm (Author) commented May 6, 2023

@Tombana, thanks for your suggestion! I have successfully removed the float32 input tensor by extending the strip_lcedequantize_ops function.
