当前位置: 首页 > news >正文

PyTorch之数据集随机值

这个可能需要新版pytorch 才可以做到o ~ 还没来及试

一个快捷的解决方案:

def worker_init_fn(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

01 关于pytorch数据集随机种子的基本认识

在pytorch中random、torch.random等随机值产生方法一般没有问题,只有少数工人运行也可以保障其不同的最终值.

np.random.seed 会出现问题的原因是,当多处理采用 fork 方式产生子进程时,numpy 不会对不同的子进程产生不同的随机值.

换言之,当没有多处理使用时,numpy 不会出现随机种子的不同的问题;实验代码的可复现性要求一个是工人种子 ,即工人内包括numpy,random,torch.random所有的随机表现;另一个是Base ,即程序运行后的初始随机值,其可以通过以下两种方式产生

  1. torch.manual_seed(base_seed)

  2. 由特定的seed generator设置

generator = torch. Generator()
g.manual_seed(base_seed)
DataLoader(dataset, ..., generator=generator)

使用spawn模式可以斩断以上所有烦恼.

02 直接在网上搜这个问题会得到什么答案

参考很多的解决方案时,往往会提出以下功能:

def worker_init_fn(worker_id):
    np.random.seed(np.random.get_state()[1][0] + worker_id)

让我们看看它的输出结果:
(第0,3列是索引,第1,4列是np.random的结果,第2,5列是random.randint的结果)

epoch 0
tensor([[    0,  5125, 13588,     0, 15905, 23182],
        [    1,  7204, 19825,     0, 13653, 25225]])
tensor([[    2,  1709, 11504,     0, 12842, 23238],
        [    3,  5715, 14058,     0, 15236, 28033]])
tensor([[    4,  1062, 11239,     0, 10142, 29869],
        [    5,  6574, 15672,     0, 19623, 25600]])
============================================================
epoch 1
tensor([[    0,  5125, 18134,     0, 15905, 28990],
        [    1,  7204, 13206,     0, 13653, 25106]])
tensor([[    2,  1709, 15512,     0, 12842, 29703],
        [    3,  5715, 14201,     0, 15236, 27696]])
tensor([[    4,  1062, 13994,     0, 10142, 23411],
        [    5,  6574, 18532,     0, 19623, 21744]])
============================================================

假设上述方案对一个时代内可以防止不同的工人出现随机值相同的情况,但不同的时代之间,其最终的随机种子仍然是不变的。

03 那应该如何解决

来自pytorch官方的解决方案:

https://github.com/pytorch/pytorch/pull/56488#issuecomment-825128350

def worker_init_fn(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
 
ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

来自numpy.random原作者的解决方案:

https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562

def worker_init_fn(id):
    process_seed = torch.initial_seed()
    # Back out the base_seed so we can use all the bits.
    base_seed = process_seed - id
    ss = np.random.SeedSequence([id, base_seed])
    # More than 128 bits (4 32-bit words) would be overkill.
    np.random.seed(ss.generate_state(4))
 
ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

一个更简单但不保证正确性的解决方案:

def worker_init_fn(worker_id):
    np.random.seed((worker_id + torch.initial_seed()) % np.iinfo(np.int32).max)

ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

04 附上可运行的完整文件

import numpy as np
import random
import torch

# np.random.seed(0)

class Transform(object):
    def __init__(self):
        pass

    def __call__(self, item = None):
        return [np.random.randint(10000, 20000), random.randint(20000,30000)]

class RandomDataset(object):
    def __init__(self):
        pass

    def __getitem__(self, ind):
        item = [ind, np.random.randint(1, 10000), random.randint(10000, 20000), 0]
        tsfm =Transform()(item)
        return np.array(item + tsfm)
    def __len__(self):
        return 20

from torch.utils.data import DataLoader

def worker_init_fn(worker_id):
    np.random.seed(np.random.get_state()[1][0] + worker_id)

ds = RandomDataset()
ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

for epoch in range(2):
    print("epoch {}".format(epoch))
    np.random.seed()
    for batch in ds:
        print(batch)

完事拉 大伙试试把      whaosoft aiot http://143ai.com  

相关文章:

  • Python基础:【习题系列】多选题(一)
  • 目标检测——3D玩具数据集
  • Python 面向对象——6.封装
  • mac电脑搭建vue环境(上篇)
  • SOLIDWORKS Electrical 3D--精准的三维布线
  • C++ | Leetcode C++题解之第44题通配符匹配
  • 183896-00-6,Biotin-C3-PEG3-C3-NH2,可以选择性降解靶蛋白
  • Android 验证启动模式
  • Docker之数据卷
  • QT Mingw编译ffmpeg源码以及测试
  • Oracle中序列
  • [LeetCode]143.重排链表
  • #边学边记 必修4 高项:对事的管理 第六章 项目质量管理之质量控制
  • docker的网络模式
  • 集中供暖热计量温控一体化管理系统
  • mybatis案例--mapper代理开发
  • 【程序运行时的两种环境】
  • vue的相关概念
  • python驾到~障碍通通闪开,美女批量入内存~
  • 【UV打印机】RYPC打印软件教程(六)-系统维护
  • 数据库语句的基本
  • 【备战蓝桥杯 | 软件Java大学B组】十三届真题深刨详解(1)
  • 【MATLAB教程案例25】常用图像变换域的matlab仿真分析——DFT频域,DCT域,小波域等
  • VcXsrv XLaunch 闪退 failed to bind listener 的解决方法
  • 一些特殊SQL使用Mybatis的#{}和${}注意点
  • rpcs3模拟器配置要求是什么?
  • paddleNLP 安装
  • 【算法笔记】位运算详解
  • 《设计模式》装饰者模式
  • SpringBoot--在Entity(DAO)中使用枚举类型
  • Session(服务端会话跟踪技术)
  • CVPR2022 BatchFormer