After PyTorch post-training quantization (PTQ), I find that the forward pass of the quantized model still seems to use dequantized float32 weights rather than quantized int8 weights. Below is the PTQ example from the PyTorch quantization documentation. I added a forward hook to capture the conv layer's inputs, weights and outputs during the forward pass. If I dequantize those tensors and do the computation manually in floating point, I get the same result as the forward-pass output, which supposedly uses int8. Feel free to run the code yourself.
How does PyTorch PTQ perform the forward pass: does it use float32 or int8 weights? And what does it mean for a model to be quantized if it still needs to store and use float32 values for inference?
import torch
# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    """Minimal quant -> conv -> relu -> dequant model for eager-mode static PTQ.

    In the float model ``quant`` and ``dequant`` are pass-through stubs. After
    ``torch.quantization.convert`` they mark where tensors are converted from
    floating point to quantized (``quant``) and back (``dequant``).
    The attribute names ``conv`` and ``relu`` are referenced by name when the
    two layers are fused, so they must not change.
    """

    def __init__(self):
        super().__init__()
        # Entry of the quantized region: float -> quantized once converted.
        self.quant = torch.quantization.QuantStub()
        # 1 input channel, 1 output channel, 1x1 kernel.
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # Exit of the quantized region: quantized -> float once converted.
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Apply quant, conv, relu and dequant, in that order, to ``x``."""
        return self.dequant(self.relu(self.conv(self.quant(x))))
# Seed the RNG so the model's initial weights and the random calibration/input
# batch -- and therefore every number printed by this script -- are identical
# on every run.
torch.manual_seed(0)
# create a model instance
model_fp32 = M()
# model must be set to eval mode for static quantization logic to work
model_fp32.eval()
# Default static-quantization config for the 'fbgemm' backend.
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# Fuse conv + relu into a single module so they are quantized as one op.
model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
# Prepare the model for static quantization. This inserts observers in
# the model that will observe activation tensors during calibration.
model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)
# calibrate the prepared model to determine quantization parameters for activations
# in a real world setting, the calibration would be done with a representative dataset
input_fp32 = torch.randn(4, 1, 4, 4)
# Calibration only needs the observers' statistics, not gradients.
with torch.no_grad():
    model_fp32_prepared(input_fp32)
# Convert the observed model to a quantized model.
model_int8 = torch.quantization.convert(model_fp32_prepared)
# hooks to retrieve inputs, outputs and weights of conv layer (fused conv + relu).
# These module-level names are filled in by hook_fn during the forward pass.
conv_inputs = None
conv_weights = None
conv_outputs = None


def hook_fn(m, i, o):
    """Forward hook: record the input, weight and output of module ``m``.

    Args:
        m: the hooked module (``model_int8.conv``); its ``weight()`` method is
            called to fetch the weight tensor.
        i: tuple of positional inputs passed to ``m.forward``; only ``i[0]``
            is kept.
        o: the module's output.
    """
    global conv_inputs, conv_outputs, conv_weights
    conv_inputs = i[0]
    conv_weights = m.weight()
    conv_outputs = o


hooks = [model_int8.conv.register_forward_hook(hook_fn)]
try:
    # run forward pass
    res = model_int8(input_fp32)
finally:
    # Detach the hook once it has fired (or the pass failed) so re-running this
    # code, e.g. in a notebook, does not stack duplicate hooks on the module.
    for hook in hooks:
        hook.remove()
    hooks.clear()

if conv_outputs is None:
    raise RuntimeError("forward hook on model_int8.conv did not fire")
# Recompute the fused conv + relu output by hand in floating point and compare
# it with the quantized output captured by the forward hook.
# NOTE(review): agreement here does not by itself show that forward() computes
# in float32. The captured weight is a quantized tensor (it has int_repr() and
# q_per_channel_scales()), and an integer kernel followed by requantization is
# presumably designed to reproduce this float reference up to rounding.
relu = torch.nn.ReLU()
qconv = model_int8.conv
out_scale = qconv.scale
out_zero_point = qconv.zero_point
# Representable range of the output's integer storage dtype.
out_range = torch.iinfo(conv_outputs.int_repr().dtype)


def _requantize(x_float):
    """Map float values to the conv output's integer codes.

    Affine quantization: ``clamp(round(x / scale) + zero_point, qmin, qmax)``.
    The result stays a floating-point tensor so it can be compared directly
    with ``conv_outputs.int_repr()``.
    """
    codes = torch.round(x_float / out_scale) + out_zero_point
    return codes.clamp(out_range.min, out_range.max)


def _float_conv(x, w, b):
    """Float convolution using the quantized module's stride/padding/dilation/groups.

    Generalizes the element-wise product, which is only valid for a 1x1 kernel
    with a single input and output channel.
    """
    return torch.nn.functional.conv2d(
        x, w, b, qconv.stride, qconv.padding, qconv.dilation, qconv.groups
    )


# Manual output 1: dequantize the integer representations by hand,
# float = (q - zero_point) * scale, computed in float64.
conv_float_input = (
    conv_inputs.int_repr().double() - conv_inputs.q_zero_point()
) * conv_inputs.q_scale()
# The weight carries per-channel parameters (q_per_channel_*). Reshape them so
# they broadcast along the quantization axis rather than the last dimension.
w_param_shape = [1] * conv_weights.dim()
w_param_shape[conv_weights.q_per_channel_axis()] = -1
w_scales = conv_weights.q_per_channel_scales().double().reshape(w_param_shape)
w_zero_points = (
    conv_weights.q_per_channel_zero_points().double().reshape(w_param_shape)
)
conv_float_weight = (conv_weights.int_repr().double() - w_zero_points) * w_scales
conv_bias = qconv.bias()
conv_float_output = _float_conv(
    conv_float_input,
    conv_float_weight,
    None if conv_bias is None else conv_bias.double(),
)
manual_output_1 = _requantize(relu(conv_float_output))
print("manual_output_1:\n", manual_output_1)

# Manual output 2: same computation using the built-in dequantize().
conv_float_output = _float_conv(
    conv_inputs.dequantize(), conv_weights.dequantize(), conv_bias
)
manual_output_2 = _requantize(relu(conv_float_output))
print("manual_output_2:\n", manual_output_2)

# Output produced by forward() pass
forward_codes = conv_outputs.int_repr()
print("output produced by forward():\n", forward_codes)

# Sum of ABSOLUTE element-wise differences (0.0 means identical). A signed sum
# would let +1 and -1 mismatches cancel out and hide real differences.
for name, manual in (
    ("manual_output_1", manual_output_1),
    ("manual_output_2", manual_output_2),
):
    total_diff = (manual - forward_codes.to(manual.dtype)).abs().sum().item()
    print(f"{name} and conv_outputs differ by: {total_diff}")