NumPy ndarray vs. PyTorch tensor (NumPy on the GPU): quick reference

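Types

| NumPy | PyTorch |
|---|---|
| `np.ndarray` | `torch.Tensor` |
| `np.float32` | `torch.float32`; `torch.float` |
| `np.float64` | `torch.float64`; `torch.double` |
| `np.float16` | `torch.float16`; `torch.half` |
| `np.int8` | `torch.int8` |
| `np.uint8` | `torch.uint8` |
| `np.int16` | `torch.int16`; `torch.short` |
| `np.int32` | `torch.int32`; `torch.int` |
| `np.int64` | `torch.int64`; `torch.long` |

Constructors

Ones and zeros

| NumPy | PyTorch |
|---|---|
| `np.empty((2, 3))` | `torch.empty(2, 3)` |
| `np.empty_like(x)` | `torch.empty_like(x)` |
| `np.eye` | `torch.eye` |
| `np.identity` | `torch.eye` |
| `np.ones` | `torch.ones` |
| `np.ones_like` | `torch.ones_like` |
| `np.zeros` | `torch.zeros` |
| `np.zeros_like` | `torch.zeros_like` |

From existing data

| NumPy | PyTorch |
|---|---|
| `np.array([[1, 2], [3, 4]])` | `torch.tensor([[1, 2], [3, 4]])` |
| `np.array([3.2, 4.3], dtype=np.float16)`; `np.float16([3.2, 4.3])` | `torch.tensor([3.2, 4.3], dtype=torch.float16)` |
| `x.copy()` | `x.clone()` |
| `np.fromfile(file)` | `torch.from_file(file, size=n, dtype=dtype)` |
| `np.frombuffer` | `torch.frombuffer` |
| `np.fromfunction` | N/A |
| `np.fromiter` | N/A |
| `np.fromstring` | N/A |
| `np.load` | `torch.load` |
| `np.loadtxt` | N/A |
| `np.concatenate` | `torch.cat` |

Numerical ranges

| NumPy | PyTorch |
|---|---|
| `np.arange(10)` | `torch.arange(10)` |
| `np.arange(2, 3, 0.1)` | `torch.arange(2, 3, 0.1)` |
| `np.linspace` | `torch.linspace` |
| `np.logspace` | `torch.logspace` |

Building matrices

| NumPy | PyTorch |
|---|---|
| `np.diag` | `torch.diag` |
| `np.tril` | `torch.tril` |
| `np.triu` | `torch.triu` |

Attributes

| NumPy | PyTorch |
|---|---|
| `x.shape` | `x.shape` |
| `x.strides` (in bytes) | `x.stride()` (in elements) |
| `x.ndim` | `x.dim()` |
| `x.data` | `x.data` |
| `x.size` | `x.nelement()` |
| `x.dtype` | `x.dtype` |

Indexing

| NumPy | PyTorch |
|---|---|
| `x[0]` | `x[0]` |
| `x[:, 0]` | `x[:, 0]` |
| `x[indices]` | `x[indices]` |
| `np.take(x, indices)` | `torch.take(x, torch.LongTensor(indices))` |
| `x[x != 0]` | `x[x != 0]` |

Shape manipulation

| NumPy | PyTorch |
|---|---|
| `x.reshape` | `x.reshape`; `x.view` |
| `x.resize()` | `x.resize_` |
| N/A | `x.resize_as_` |
| `x.transpose` | `x.transpose` (swaps two dims) or `x.permute` (reorders all dims) |
| `x.flatten` | `x.flatten()`; `x.view(-1)` |
| `x.squeeze()` | `x.squeeze()` |
| `x[:, np.newaxis]`; `np.expand_dims(x, 1)` | `x.unsqueeze(1)` |

Item selection and manipulation

| NumPy | PyTorch |
|---|---|
| `np.put` | N/A |
| `x.put` | `x.put_` |
| `x = np.array([1, 2, 3])`; `x.repeat(2)` → `[1, 1, 2, 2, 3, 3]` | `x = torch.tensor([1, 2, 3])`; `x.repeat(2)` → `[1, 2, 3, 1, 2, 3]`; `x.repeat(2).reshape(2, -1).transpose(1, 0).reshape(-1)` or `torch.repeat_interleave(x, 2)` → `[1, 1, 2, 2, 3, 3]` |
| `np.tile(x, (3, 2))` | `x.repeat(3, 2)` |
| `np.choose` | N/A |
| `np.sort` | `sorted, indices = torch.sort(x, [dim])` |
| `np.argsort` | `sorted, indices = torch.sort(x, [dim])`; `torch.argsort(x, dim)` |
| `np.nonzero` | `torch.nonzero` |
| `np.where` | `torch.where` |
| `x[::-1]` | `torch.flip(x, [0])` |

Calculation

| NumPy | PyTorch |
|---|---|
| `x.min` | `x.min` |
| `x.argmin` | `x.argmin` |
| `x.max` | `x.max` |
| `x.argmax` | `x.argmax` |
| `x.clip` | `x.clamp` |
| `x.round` | `x.round` |
| `np.floor(x)` | `torch.floor(x)`; `x.floor()` |
| `np.ceil(x)` | `torch.ceil(x)`; `x.ceil()` |
| `x.trace` | `x.trace` |
| `x.sum` | `x.sum` |
| `x.cumsum` | `x.cumsum` |
| `x.mean` | `x.mean` |
| `x.std()` (divides by N) | `x.std(unbiased=False)` (plain `x.std()` divides by N - 1) |
| `x.prod` | `x.prod` |
| `x.cumprod` | `x.cumprod` |
| `x.all()` | `x.all()`; for a 0/1 tensor also `(x == 1).sum() == x.nelement()` |
| `x.any()` | `x.any()`; for a 0/1 tensor also `(x == 1).sum() > 0` |

Comparisons

| NumPy | PyTorch |
|---|---|
| `np.less` | `x.lt` |
| `np.less_equal` | `x.le` |
| `np.greater` | `x.gt` |
| `np.greater_equal` | `x.ge` |
| `np.equal` | `x.eq` |
| `np.not_equal` | `x.ne` |

PyTorch vs. TensorFlow API cheat sheet

| Operation | PyTorch | TensorFlow | NumPy |
|---|---|---|---|
| Clip to a range | `torch.clamp(x, min, max)` | `tf.clip_by_value(x, min, max)` | `np.clip(x, min, max)` |
| Minimum along an axis | `torch.min(x, dim)[0]` | `tf.reduce_min(x, axis)` | `np.min(x, axis)` |
| Element-wise maximum of two tensors | `torch.max(x, y)` | `tf.maximum(x, y)` | `np.maximum(x, y)` |
| Element-wise minimum of two tensors | `torch.min(x, y)` | `tf.minimum(x, y)` | `np.minimum(x, y)` |
| Index of the maximum | `torch.max(x, dim)[1]` | `tf.argmax(x, axis)` | `np.argmax(x, axis)` |
| Index of the minimum | `torch.min(x, dim)[1]` | `tf.argmin(x, axis)` | `np.argmin(x, axis)` |
| Compare (x > y) | `torch.gt(x, y)` | `tf.greater(x, y)` | `np.greater(x, y)` |
| Compare (x < y) | `torch.lt(x, y)` | `tf.less(x, y)` | `np.less(x, y)` |
| Compare (x == y) | `torch.eq(x, y)` | `tf.equal(x, y)` | `np.equal(x, y)` |
| Compare (x != y) | `torch.ne(x, y)` | `tf.not_equal(x, y)` | `np.not_equal(x, y)` |
| Indices where a condition holds | `torch.nonzero(cond)` | `tf.where(cond)` | `np.where(cond)` |
| Concatenate tensors | `torch.cat([x, y], dim)` | `tf.concat([x, y], axis)` | `np.concatenate([x, y], axis)` |
| Stack into one tensor | `torch.stack([x1, x2], dim)` | `tf.stack([x1, x2], axis)` | `np.stack([x, y], axis)` |
| Split one tensor into several | `torch.split(x1, split_size_or_sections, dim)` | `tf.split(x1, num_or_size_splits, axis)` | `np.split(x1, indices_or_sections, axis)` |
| Unstack along an axis | `torch.unbind(x1, dim)` | `tf.unstack(x1, axis)` | N/A |
| Random shuffle | `torch.randperm(n)` ¹ | `tf.random.shuffle(x)` (TF 1.x: `tf.random_shuffle(x)`) | `np.random.shuffle(x)` ²; `np.random.permutation(x)` ³ |
| Top-k values | `torch.topk(x, k, dim, largest, sorted)` | `tf.nn.top_k(x, k, sorted)` | N/A |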

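  1. `torch.randperm(n)` can only shuffle the integers 0 to n-1. To shuffle data, first shuffle the indices, then index the data with the shuffled indices.
  2. Shuffles the array in place and returns `None`.
  3. Leaves the input unchanged and returns a shuffled copy.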
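Footnote 1 describes how to shuffle data in PyTorch by shuffling indices. The sketch below contrasts that approach with the two NumPy functions from footnotes 2 and 3. It assumes NumPy and PyTorch are installed; the seed and array contents are arbitrary illustrative choices.

```python
import numpy as np
import torch

torch.manual_seed(0)  # arbitrary seed so the run is reproducible

# PyTorch: randperm(n) only permutes the integers 0..n-1,
# so permute the row indices and then gather the rows with them.
data = torch.arange(12).reshape(6, 2)   # 6 rows to shuffle
idx = torch.randperm(data.size(0))      # a random permutation of 0..5
shuffled = data[idx]                    # new tensor; data itself is unchanged

# NumPy: permutation() returns a shuffled copy along the first axis ...
arr = np.arange(12).reshape(6, 2)
arr_copy = np.random.permutation(arr)   # arr is unchanged here

# ... while shuffle() reorders the rows of arr in place and returns None.
ret = np.random.shuffle(arr)
```

In all three cases whole rows move together, so each row's contents stay intact; only their order changes.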
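The `repeat`, `np.tile` and `x[::-1]` rows of the item-selection table are the easiest to get wrong when porting code. The sketch below checks them side by side. It assumes a reasonably recent PyTorch, where `torch.repeat_interleave` and `torch.flip` are available; the input values are arbitrary.

```python
import numpy as np
import torch

a = np.array([1, 2, 3])
t = torch.tensor([1, 2, 3])

# NumPy's repeat duplicates each element in place: [1, 1, 2, 2, 3, 3]
np_repeat = a.repeat(2)

# PyTorch's repeat tiles the whole tensor instead: [1, 2, 3, 1, 2, 3]
torch_tiled = t.repeat(2)

# Two ways to get NumPy-style element repetition in PyTorch
via_reshape = t.repeat(2).reshape(2, -1).transpose(1, 0).reshape(-1)
via_interleave = torch.repeat_interleave(t, 2)

# np.tile(x, (3, 2)) corresponds to x.repeat(3, 2); both have shape (3, 6)
np_tile = np.tile(a, (3, 2))
torch_tile = t.repeat(3, 2)

# Tensors reject negative slice steps, so x[::-1] becomes torch.flip
np_reversed = a[::-1]
torch_reversed = torch.flip(t, dims=[0])
```

The reshape/transpose trick works because `transpose` makes the tensor non-contiguous, and `reshape` then copies it into the interleaved order. `repeat_interleave` states the intent more directly.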

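Several rows in the tables above have matching names but different behavior:

- NumPy's `strides` counts bytes, while PyTorch's `stride()` counts elements.
- `torch.min(x, dim)` returns a (values, indices) pair, which is why the cheat sheet uses `[0]` and `[1]`.
- `std` uses a different default divisor in the two libraries.
- `split` interprets an integer argument differently.

The sketch below checks each point. It assumes float32 data; the values and variable names are arbitrary.

```python
import numpy as np
import torch

x = torch.tensor([[3.0, 1.0, 2.0],
                  [0.0, 5.0, 4.0]])     # float32, shape (2, 3)
a = x.numpy()                            # NumPy array sharing x's memory

# Strides: NumPy counts bytes, PyTorch counts elements.
np_strides = a.strides                   # (12, 4) for this float32 array
torch_strides = x.stride()               # (3, 1)

# torch.min with a dim returns (values, indices), hence [0] / [1].
values, indices = torch.min(x, dim=1)
np_values = np.min(a, axis=1)
np_indices = np.argmin(a, axis=1)

# std: NumPy divides by N (ddof=0); PyTorch divides by N - 1 by default.
np_std = np.std(a)
torch_std_default = x.std()
torch_std_population = x.std(unbiased=False)

# split: an integer means "chunk size" in PyTorch
# but "number of equal sections" in NumPy.
torch_parts = torch.split(torch.arange(6), 2)   # three chunks of size 2
np_parts = np.split(np.arange(6), 2)            # two sections of size 3
```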