【Precision Alignment 1】Aligning tf2.x and PyTorch model precision
Today we walk through aligning the output precision of four layer pairs between tf2.x and PyTorch:
1. nn.Linear vs layers.Dense
2. nn.Conv1d vs layers.Conv1D
3. nn.Embedding vs layers.Embedding
4. nn.GRU vs layers.GRU
Imports and helper utilities
```python
import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorflow import keras

# Let TF grow GPU memory on demand instead of reserving it all up front
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(device=gpu, enable=True)

def reorder_process(x):
    """Reorder GRU gate blocks from PyTorch's (r, z, n) layout to TF's (z, r, h)."""
    reorder_index = [4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11]  # hard-coded for hidden_size=4
    if len(x.shape) == 2:
        return x[:, reorder_index]
    else:
        return x[reorder_index]

def pad(x, kernel_size=3, dilation=1):
    """Explicit symmetric padding (for stride = 1, 2 or 3)."""
    pad_total = dilation * (kernel_size - 1)
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    return F.pad(x, pad=(pad_beg, pad_end, pad_beg, pad_end))

def compare_difference(a, b):
    """Compare a PyTorch tensor `a` against a TF tensor `b` elementwise."""
    o = np.abs(a.detach().numpy() - b.numpy()).max()
    print(f"max difference is {o}")
    o = np.abs(a.detach().numpy() - b.numpy()).mean()
    print(f"mean difference is {o}")
```
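The hard-coded reorder_index only covers hidden_size=4. As a minimal sketch (the helper name build_reorder_index is my own, not from the original), the same index can be generated for any hidden size:

```python
def build_reorder_index(hidden_size):
    # PyTorch concatenates GRU gate weights as [r | z | n], each block of width
    # hidden_size; Keras expects [z | r | h], so swap the first two blocks.
    r = list(range(0, hidden_size))
    z = list(range(hidden_size, 2 * hidden_size))
    n = list(range(2 * hidden_size, 3 * hidden_size))
    return z + r + n

assert build_reorder_index(4) == [4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11]
```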
1. nn.Linear vs layers.Dense
```python
# Build the models
pt_linear = nn.Linear(2, 4)
tf_dense = keras.layers.Dense(4)
tf_dense.build(input_shape=(None, 2))  # the layer has no weights until build() is called

# Copy the weights (note the transpose)
weight = pt_linear.weight.data.T.numpy()
bias = pt_linear.bias.data.numpy()
tf_dense.set_weights(weights=[weight, bias])

# Compare
x = np.random.randn(2, 2).astype(np.float32)
pt_x = torch.from_numpy(x)
tf_x = tf.constant(x)
a = pt_linear(pt_x)
b = tf_dense(tf_x)
compare_difference(a, b)
# max difference is 0.0
# mean difference is 0.0
```
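Why the transpose: the two frameworks store the kernel with opposite axis order. A quick sanity check (this snippet is my addition, not from the original):

```python
print(pt_linear.weight.shape)  # torch.Size([4, 2]) -> (out_features, in_features)
print(tf_dense.kernel.shape)   # (2, 4)             -> (in_features, out_features)
```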
2. nn.Conv1d vs layers.Conv1D
```python
# Build the models
pt_conv1d = nn.Conv1d(in_channels=3, out_channels=2, kernel_size=1)
tf_conv1d = keras.layers.Conv1D(filters=2, kernel_size=1)
tf_conv1d.build(input_shape=(None, 2, 3))

# pt weight shape: [out_channels, in_channels, kernel_size] = [2, 3, 1]
# tf weight shape: [kernel_size, in_channels, filters]      = [1, 3, 2]
# pt input format: [bs, in_channels, length]
# tf input format: [bs, length, in_channels]

# Copy the weights
weight = pt_conv1d.weight.data.T.numpy()
bias = pt_conv1d.bias.data.numpy()
tf_conv1d.set_weights(weights=[weight, bias])

# Compare
x = np.random.randn(1, 2, 3).astype(np.float32)
pt_x = torch.from_numpy(x).transpose(1, 2)  # to channels-first for PyTorch
tf_x = tf.constant(x)
a = pt_conv1d(pt_x).transpose(1, 2)  # back to channels-last for the comparison
b = tf_conv1d(tf_x)
compare_difference(a, b)
# max difference is 0.0
# mean difference is 0.0
```
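Here .T on the 3-D kernel reverses all axes, mapping (out_channels, in_channels, kernel_size) to (kernel_size, in_channels, filters), which is exactly the Keras layout. As a sketch (my substitution, not from the original), an explicit permute does the same thing while making the axis mapping obvious, and avoids the deprecation warning newer PyTorch versions emit for .T on tensors with more than 2 dimensions:

```python
# Same mapping as .T, but explicit about which axis goes where
weight = pt_conv1d.weight.data.permute(2, 1, 0).numpy()
```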
3. nn.Embedding vs layers.Embedding
```python
# Build the models
pt_embedding = nn.Embedding(num_embeddings=3, embedding_dim=4)
tf_embedding = keras.layers.Embedding(input_dim=3, output_dim=4, mask_zero=False)
tf_embedding.build(input_shape=(None, 3))

# Copy the weights (both store the table as [num_embeddings, embedding_dim], so no transpose)
weight = pt_embedding.weight.data.numpy()
tf_embedding.set_weights(weights=[weight])

# Compare
x = np.array([0, 1, 2]).astype(np.int64)
pt_x = torch.from_numpy(x)
tf_x = tf.constant(x)
a = pt_embedding(pt_x)
b = tf_embedding(tf_x)
compare_difference(a, b)
# max difference is 0.0
# mean difference is 0.0
```
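One caveat when padding tokens are involved (my note, not from the original): the two frameworks' padding options are not numerically interchangeable, so keep mask_zero=False when only the lookup itself is being aligned:

```python
# PyTorch: padding_idx zero-initializes row 0 and keeps its gradient at zero
pt_emb = nn.Embedding(num_embeddings=3, embedding_dim=4, padding_idx=0)
# Keras: mask_zero=True leaves row 0 as a real vector and only attaches a mask
# that downstream mask-aware layers consume
tf_emb = keras.layers.Embedding(input_dim=3, output_dim=4, mask_zero=True)
```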
4. nn.GRU vs layers.GRU
```python
# Build the models
pt_gru = nn.GRU(input_size=2, hidden_size=4, batch_first=True,
                num_layers=1, bidirectional=False)
tf_gru = keras.layers.GRU(units=4, return_sequences=True, return_state=True)
tf_gru.build(input_shape=(None, 3, 2))

# pt GRU gate order: r, z, n
# tf GRU gate order: z, r, h
# The gate orders differ, so the weights must be reordered!

# Copy the weights
input_kernel = reorder_process(pt_gru.weight_ih_l0.T.data.numpy())
recur_kernel = reorder_process(pt_gru.weight_hh_l0.T.data.numpy())
bias = torch.stack([reorder_process(pt_gru.bias_ih_l0.data),
                    reorder_process(pt_gru.bias_hh_l0.data)]).numpy()
tf_gru.set_weights(weights=[input_kernel, recur_kernel, bias])

# Compare
x = np.random.randn(1, 3, 2).astype(np.float32)
pt_x = torch.from_numpy(x)
tf_x = tf.constant(x)
pt_outputs, pt_hidden_states = pt_gru(pt_x)
tf_outputs, tf_hidden_states = tf_gru(tf_x)
compare_difference(pt_outputs, tf_outputs)
# max difference is 5.960464477539063e-08
# mean difference is 1.7210064484629584e-08
compare_difference(pt_hidden_states, tf_hidden_states)
# max difference is 2.9802322387695312e-08
# mean difference is 7.8580342233181e-09
```
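Why this alignment works at all (my note): TF2's GRU defaults to reset_after=True, which applies the reset gate to the already-biased recurrent term, matching PyTorch's formulation n_t = tanh(W_in·x + b_in + r_t ⊙ (W_hn·h + b_hn)); that is also why the bias must be stacked into shape (2, 3·units), mirroring PyTorch's separate bias_ih / bias_hh pair. The residual ~1e-8 differences are float32 round-off, which a tolerance check confirms:

```python
# Differences at the 1e-8 level are float32 round-off, not a layout mismatch
assert np.allclose(pt_outputs.detach().numpy(), tf_outputs.numpy(), atol=1e-6)
```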
To be continued:
- GRU bidirectional
- BatchNorm
- LayerNorm
- Conv2d valid padding
- Conv2d same padding
For the second half of this series, head over to 【Precision Alignment 2】Aligning tf2.x and PyTorch model precision, which covers:
5. GRU bidirectional
6. nn.BatchNorm1d vs layers.BatchNormalization
7. nn.LayerNorm vs layers.LayerNormalization
8. Conv2d valid padding
9. Conv2d same padding