Tensorflow Lite从入门到精通( 八 )

更多细节请参考:Tensorflow官方文档训练后整型量化
float16量化现在,TensorFlow Lite 支持在模型从 TensorFlow 转换到 TensorFlow Lite FlatBuffer 格式期间将权重转换为 16 位浮点值 。这样可以将模型的大小缩减至原来的二分之一 。某些硬件(如 GPU)可以在这种精度降低的算术中以原生方式计算,从而实现比传统浮点执行更快的速度 。可以将 Tensorflow Lite GPU 委托配置为以这种方式运行 。但是,转换为 float16 权重的模型仍可在 CPU 上运行而无需其他修改:float16 权重会在首次推理前上采样为 float32 。这样可以在对延迟和准确率造成最小影响的情况下显著缩减模型大小 。
float16 量化的优点如下:

  • 将模型的大小缩减一半(因为所有权重都变成其原始大小的一半) 。
  • 实现最小的准确率损失 。
  • 支持可直接对 float16 数据进行运算的部分委托(例如 GPU 委托),从而使执行速度比 float32 计算更快 。
float16 量化的缺点如下:
  • 它不像对定点数学进行量化那样减少那么多延迟 。
  • 默认情况下,float16 量化模型在 CPU 上运行时会将权重值“反量化”为 float32 。(请注意,GPU 委托不会执行此反量化,因为它可以对 float16 数据进行运算)
import tensorflow as tfconverter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)converter.optimizations = [tf.lite.Optimize.DEFAULT]converter.target_spec.supported_types = [tf.float16]tflite_quant_model = converter.convert()在本教程中,您将从头开始训练一个 MNIST 模型,并在 TensorFlow 中检查其准确率,然后使用 float16 量化将此模型转换为 Tensorflow Lite FlatBuffer 格式 。最后,检查转换后模型的准确率,并将其与原始 float32 模型进行比较 。
Tensorflow Lite从入门到精通

文章插图
Tensorflow Lite从入门到精通

文章插图
# -*- coding:utf-8 -*-# Author:凌逆战 | Never# Date: 2022/10/12""""""import logginglogging.getLogger("tensorflow").setLevel(logging.DEBUG)import tensorflow as tffrom tensorflow import kerasimport numpy as npimport pathlib# Load MNIST datasetmnist = keras.datasets.mnist(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# Normalize the input image so that each pixel value is between 0 to 1.train_images = train_images / 255.0test_images = test_images / 255.0# Define the model architecturemodel = keras.Sequential([    keras.layers.InputLayer(input_shape=(28, 28)),    keras.layers.Reshape(target_shape=(28, 28, 1)),    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu),    keras.layers.MaxPooling2D(pool_size=(2, 2)),    keras.layers.Flatten(),    keras.layers.Dense(10)])# Train the digit classification modelmodel.compile(optimizer='adam',              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),              metrics=['accuracy'])model.fit(train_images, train_labels, epochs=1, validation_data=https://www.huyubaike.com/biancheng/(test_images, test_labels))converter = tf.lite.TFLiteConverter.from_keras_model(model)tflite_model = converter.convert()tflite_models_dir = pathlib.Path("./mnist_tflite_models/")tflite_models_dir.mkdir(exist_ok=True, parents=True)tflite_model_file = tflite_models_dir / "mnist_model.tflite"tflite_model_file.write_bytes(tflite_model)# float16 量化 ----------------------------------------------converter.optimizations = [tf.lite.Optimize.DEFAULT]converter.target_spec.supported_types = [tf.float16]tflite_fp16_model = converter.convert() # 转换为TFLitetflite_model_fp16_file = tflite_models_dir / "mnist_model_quant_f16.tflite"tflite_model_fp16_file.write_bytes(tflite_fp16_model)# 将模型加载到解释器中 ------------------------------# 没有量化的 TFliteinterpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))interpreter.allocate_tensors()# float16 量化的 TFLiteinterpreter_fp16 = tf.lite.Interpreter(model_path=str(tflite_model_fp16_file))interpreter_fp16.allocate_tensors()# 使用 "test" 数据集评估TF Lite模型def evaluate_model(interpreter): input_index = interpreter.get_input_details()[0]["index"] output_index = interpreter.get_output_details()[0]["index"] # 对“test”数据集中的每个图像进行预测 prediction_digits = [] for test_image in test_images: # 预处理: 添加batch维度并转换为float32以匹配模型的输入数据格式 test_image = np.expand_dims(test_image, axis=0).astype(np.float32) interpreter.set_tensor(input_index, test_image) interpreter.invoke() # 运行推理 # 后处理:去除batch维度 找到概率最高的数字 output = interpreter.tensor(output_index) digit = np.argmax(output()[0]) prediction_digits.append(digit) # 将预测结果与ground truth 标签进行比较,计算精度 。 accurate_count = 0 for index in range(len(prediction_digits)): if prediction_digits[index] == test_labels[index]: accurate_count += 1 accuracy = accurate_count * 1.0 / len(prediction_digits) return accuracyprint(evaluate_model(interpreter)) # 0.9662# NOTE: Colab运行在服务器的cpu上 。# 在写这篇文章的时候,TensorFlow Lite还没有超级优化的服务器CPU内核 。# 由于这个原因,它可能比上面的float interpreter要慢# 但是对于mobile CPUs,可以观察到相当大的加速print(evaluate_model(interpreter_fp16)) # 0.9662

经验总结扩展阅读