未来 , TensorFlow Lite 将降低剪枝后模型的延迟 。
聚类:聚类的工作原理是将模型中每一层的权重归入预定数量的聚类中,然后共享属于每个单独聚类的权重的质心值 。这就减少了模型中唯一权重值的数量,从而降低了其复杂性 。这样一来,就可以更高效地压缩聚类后的模型,从而提供类似于剪枝的部署优势 。
开发工作流程
- 首先,检查托管模型中的模型能否用于您的应用 。如果不能,建议从训练后量化工具开始,因为它适用范围广,且无需训练数据 。
- 对于无法达到准确率和延迟目标 , 或硬件加速器支持很重要的情况,量化感知训练是更好的选择 。请参阅TensorFlow Model Optimization Toolkit下的其他优化技术 。
- 如果要进一步缩减模型大小,可以在量化模型之前尝试剪枝和/或聚类 。
量化感知训练等我将代码进行了训练后量化后,再来整这个量化感知训练
训练后量化量化的工作原理是降低模型参数的精度(默认情况为 32 位浮点数) 。这样可以获得较小的模型大小和较快的计算速度 。TensorFlow Lite 提供以下量化类型:
技术数据要求大小缩减准确率训练后 Float16 量化无数据高达 50%轻微的准确率损失训练后 动态范围量化无数据高达 75%,速度加快 2-3倍极小的准确率损失训练后 int8 量化无标签的代表性样本高达 75%,速度加快3+倍极小的准确率损失量化感知训练带标签的训练数据高达 75%极小的准确率损失以下决策树可帮助您仅根据预期的模型大小和准确率来选择要用于模型的量化方案 。
文章插图
动态范围量化权重(float32) 会在训练后量化为 整型(int8),激活会在推断时动态量化,模型大小缩减至原来的四分之一:
TFLite 支持对激活进行动态量化(激活始终以浮点进行存储 。对于支持量化内核的算子,激活会在处理前动态量化为 8 位精度,并在处理后反量化为浮点精度 。根据被转换的模型 , 这可以提供比纯浮点计算更快的速度)以实现以下效果:
- 在可用时使用量化内核加快实现速度 。
- 将计算图不同部分的浮点内核与量化内核混合 。
我们来吃一个完整的栗子,构建一个MNIST模型,并且对它使用动态范围量化,对比量化前后的精度变化:
文章插图
文章插图
# -*- coding:utf-8 -*-# Author:凌逆战 | Never# Date: 2022/10/12"""动态范围量化"""import logginglogging.getLogger("tensorflow").setLevel(logging.DEBUG)import tensorflow as tffrom tensorflow import kerasimport numpy as npimport pathlib# 加载MNIST数据集mnist = keras.datasets.mnist(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# 归一化输入图像,使每个像素值在0到1之间 。train_images = train_images / 255.0test_images = test_images / 255.0# 定义模型结构model = keras.Sequential([ keras.layers.InputLayer(input_shape=(28, 28)), keras.layers.Reshape(target_shape=(28, 28, 1)), keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu), keras.layers.MaxPooling2D(pool_size=(2, 2)), keras.layers.Flatten(), keras.layers.Dense(10)])# 训练数字分类模型model.compile(optimizer='adam', loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])model.fit(train_images, train_labels, epochs=1, validation_data=https://www.huyubaike.com/biancheng/(test_images, test_labels))# TF model to TFLiteconverter = tf.lite.TFLiteConverter.from_keras_model(model)tflite_model = converter.convert()tflite_models_dir = pathlib.Path("./mnist_tflite_models/")tflite_models_dir.mkdir(exist_ok=True, parents=True)tflite_model_file = tflite_models_dir / "mnist_model.tflite"tflite_model_file.write_bytes(tflite_model) # 84824# with open('model.tflite', 'wb') as f:# f.write(tflite_model)# 量化模型 ------------------------------------converter.optimizations = [tf.lite.Optimize.DEFAULT]tflite_quant_model = converter.convert()tflite_model_quant_file = tflite_models_dir / "mnist_model_quant.tflite"tflite_model_quant_file.write_bytes(tflite_quant_model) # 24072# 将模型加载到解释器中interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))interpreter.allocate_tensors() # 分配张量interpreter_quant = tf.lite.Interpreter(model_path=str(tflite_model_quant_file))interpreter_quant.allocate_tensors() # 分配张量# 在单个图像上测试模型test_image = np.expand_dims(test_images[0], axis=0).astype(np.float32)input_index = interpreter.get_input_details()[0]["index"]output_index = interpreter.get_output_details()[0]["index"]interpreter.set_tensor(input_index, test_image)interpreter.invoke()predictions = interpreter.get_tensor(output_index)print(predictions)# 使用“test”数据集评估TF Lite模型def evaluate_model(interpreter): input_index = interpreter.get_input_details()[0]["index"] output_index = interpreter.get_output_details()[0]["index"] # 对“test”数据集中的每个图像进行预测 prediction_digits = [] for test_image in test_images: # 预处理:添加batch维度并转换为float32以匹配模型的输入数据格式 。 test_image = np.expand_dims(test_image, axis=0).astype(np.float32) interpreter.set_tensor(input_index, test_image) interpreter.invoke() # 运行推理 # 后处理:去除批尺寸,找到概率最高的数字 output = interpreter.tensor(output_index) digit = np.argmax(output()[0]) prediction_digits.append(digit) # 将预测结果与ground truth 标签进行比较,计算精度 。 accurate_count = 0 for index in range(len(prediction_digits)): if prediction_digits[index] == test_labels[index]: accurate_count += 1 accuracy = accurate_count * 1.0 / len(prediction_digits) return accuracyprint(evaluate_model(interpreter)) # 0.958print(evaluate_model(interpreter_quant)) # 0.9579
推荐阅读
- 请问去未知暗殿怎么走(复古未知暗殿从哪里进图解)
- 神州风闻录最新侍从强度排行是怎么样的
- iphone13几个摄像头_iphone13四个摄像头吗
- DevOps|从特拉斯辞职风波到研发效能中的不靠谱人干的荒唐事
- iqoo9参数详细_iqoo9参数配置
- 小米11 Lite参数配置_小米11 Lite参数详情
- 花心男人面相,可以从3个方面进行分析,助力快速了解
- 宝宝耳朵面相怎么看,可以从这4个方面来看
- fps值太低怎么办(fps从100多突然变低到10几)
- fps太低怎么办(fps从100多突然变低到10几)