
文章插图

文章插图
# -*- coding:utf-8 -*-# Author:凌逆战 | Never# Date: 2022/10/12"""动态范围量化"""import logginglogging.getLogger("tensorflow").setLevel(logging.DEBUG)import tensorflow as tffrom tensorflow import kerasimport numpy as npimport pathlib# 加载MNIST数据集mnist = keras.datasets.mnist(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# 归一化输入图像,使每个像素值在0到1之间 。train_images = train_images / 255.0test_images = test_images / 255.0# 定义模型结构model = keras.Sequential([ keras.layers.InputLayer(input_shape=(28, 28)), keras.layers.Reshape(target_shape=(28, 28, 1)), keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu), keras.layers.MaxPooling2D(pool_size=(2, 2)), keras.layers.Flatten(), keras.layers.Dense(10)])# 训练数字分类模型model.compile(optimizer='adam', loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])model.fit(train_images, train_labels, epochs=1, validation_data=https://www.huyubaike.com/biancheng/(test_images, test_labels))# TF model to TFLiteconverter = tf.lite.TFLiteConverter.from_keras_model(model)tflite_model = converter.convert()tflite_models_dir = pathlib.Path("./mnist_tflite_models/")tflite_models_dir.mkdir(exist_ok=True, parents=True)tflite_model_file = tflite_models_dir / "mnist_model.tflite"tflite_model_file.write_bytes(tflite_model) # 84824# with open('model.tflite', 'wb') as f:# f.write(tflite_model)# 量化模型 ------------------------------------converter.optimizations = [tf.lite.Optimize.DEFAULT]tflite_quant_model = converter.convert()tflite_model_quant_file = tflite_models_dir / "mnist_model_quant.tflite"tflite_model_quant_file.write_bytes(tflite_quant_model) # 24072# 将模型加载到解释器中interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))interpreter.allocate_tensors() # 分配张量interpreter_quant = tf.lite.Interpreter(model_path=str(tflite_model_quant_file))interpreter_quant.allocate_tensors() # 分配张量# 在单个图像上测试模型test_image = np.expand_dims(test_images[0], axis=0).astype(np.float32)input_index = interpreter.get_input_details()[0]["index"]output_index = interpreter.get_output_details()[0]["index"]interpreter.set_tensor(input_index, test_image)interpreter.invoke()predictions = interpreter.get_tensor(output_index)print(predictions)# 使用“test”数据集评估TF Lite模型def evaluate_model(interpreter): input_index = interpreter.get_input_details()[0]["index"] output_index = interpreter.get_output_details()[0]["index"] # 对“test”数据集中的每个图像进行预测 prediction_digits = [] for test_image in test_images: # 预处理:添加batch维度并转换为float32以匹配模型的输入数据格式 。 test_image = np.expand_dims(test_image, axis=0).astype(np.float32) interpreter.set_tensor(input_index, test_image) interpreter.invoke() # 运行推理 # 后处理:去除批尺寸,找到概率最高的数字 output = interpreter.tensor(output_index) digit = np.argmax(output()[0]) prediction_digits.append(digit) # 将预测结果与ground truth 标签进行比较,计算精度 。 accurate_count = 0 for index in range(len(prediction_digits)): if prediction_digits[index] == test_labels[index]: accurate_count += 1 accuracy = accurate_count * 1.0 / len(prediction_digits) return accuracyprint(evaluate_model(interpreter)) # 0.958print(evaluate_model(interpreter_quant)) # 0.9579
经验总结扩展阅读
- 幼师好还是护士好 哪个更吃香
- 2023年2月6日是买衣服的黄道吉日吗 2023年2月6日买衣服黄道吉日
- 2023年2月6日适合制作寿衣吗 2023年2月6日是制作寿衣吉日吗
- 2023年2月6日买鸡黄道吉日 2023年2月6日买鸡行吗
- 2023年2月6日买牛好不好 2023年2月6日买牛吉日一览表
- 2023年2月6日画画好不好 2023年农历正月十六画画吉日
- 2023年10月1日入学行吗 2023年农历八月十七入学吉日
- 2023年10月1日是举办成人仪式的黄道吉日吗 2023年10月1日举办成人仪式吉日一览表
- 2023年10月1日上学好不好 2023年10月1日适合上学吗
- 2023年10月1日清扫房屋好吗 2023年农历八月十七宜清扫房屋吗