[Notes] Multi-GPU Parallel Training of Deep Learning Models in Docker

type
Post
status
Published
summary
In deep learning, training can be divided into model parallelism and data parallelism depending on the parallelization strategy. When doing data-parallel training with PyTorch on NVIDIA GPUs, there are two options: torch.nn.DataParallel (DP) and torch.nn.DistributedDataParallel (DDP). Today's focus is single-machine multi-GPU training with torch.nn.DataParallel (DP) inside Docker.
slug
llm-parallel-train
date
Apr 18, 2025
tags
PD
Distributed Training
Single-Machine Multi-GPU
Docker
category
Practical Tips
In deep learning, training approaches are divided into model parallelism and data parallelism, depending on the parallelization strategy:
  • Model parallelism: used when the model is too large for the memory of a single GPU. The model is split into several parts, each loaded onto a different GPU for training.
  • Data parallelism: the case you will meet most often in day-to-day work. Each GPU holds a full copy of the model, and a batch of samples is split into shards that are dispatched to the GPUs and computed in parallel. Because differentiation and summation are linear operations, data parallelism is also mathematically sound: it effectively enlarges the batch size, giving more accurate gradient estimates or faster training.
When training with data parallelism in PyTorch on NVIDIA GPUs, there are two options: torch.nn.DataParallel (DP) and torch.nn.DistributedDataParallel (DDP). For a detailed comparison of the two, see: PyTorch Quick Reference.
Today's focus is single-machine multi-GPU training with torch.nn.DataParallel (DP) inside Docker.
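Before getting into the Docker specifics, here is a minimal, hedged sketch of what DP usage looks like in plain PyTorch; the toy model and tensor shapes below are made up purely for illustration:

import torch
import torch.nn as nn

# A tiny throwaway model, just to illustrate the DataParallel wrapping pattern.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))

if torch.cuda.device_count() > 1:
    # Replicate the model onto all visible GPUs; inputs are split along dim 0 at forward time.
    model = nn.DataParallel(model)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

x = torch.randn(64, 128, device=next(model.parameters()).device)
y = model(x)      # each GPU processes a slice of the 64-sample batch
print(y.shape)    # torch.Size([64, 10])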

Preparing the Docker Environment

This is where the key point lies; more precisely, special care is needed when creating the container:
docker run -d --name test_container -v /host/path:/app/ --gpus all --shm-size=16g -it test_image bash
The most important parts of the command above are:
  • --gpus all: allows the container to use all GPUs; you can also pin it to specific GPUs, e.g. --gpus '"device=0,2"' to use GPU 0 and GPU 2
  • --shm-size=16g: sets the size of the container's shared memory, which multi-GPU communication and data exchange rely on, so size it according to your workload. If unset, Docker allocates only 64 MB by default, which is usually not enough, and training will keep failing with NCCL errors like the ones below (a quick way to verify both settings from inside the container is sketched after this list):
    • misc/shmutils.cc:72 NCCL WARN Error: failed to extend /dev/shm/nccl-xxx to 9637892 bytes
    • RuntimeError: NCCL Error 2: unhandled system error (run with NCCL_DEBUG=INFO for details)
  • The remaining flags are just the usual ones for starting a Docker container; for details, see: Docker Study Notes
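Once the container is running, it is worth confirming that both flags took effect before launching training. A small check, run from inside the container (assuming PyTorch is installed there), might look like this:

import shutil
import torch

# Verify the shared-memory size and GPU visibility from inside the container.
total, _, _ = shutil.disk_usage("/dev/shm")
print(f"/dev/shm size: {total / 1024**3:.1f} GiB")    # expect ~16 GiB with --shm-size=16g
print(f"Visible GPUs:  {torch.cuda.device_count()}")  # expect the count implied by --gpus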

Code Adjustments

On the code side, the main change is wrapping the model object in a DataParallel module. If the model comes from an existing library, you can subclass it and override the relevant methods to make it parallel-capable.
My adjustments are shown below, for reference only.
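One detail worth calling out before the full listings: once a model is wrapped in nn.DataParallel, only forward() is dispatched through the wrapper; custom methods defined on the original class have to be reached via .module. The ToyDiscriminator below is a made-up stand-in for the real Discriminator used later, just to show the pattern:

import torch
import torch.nn as nn

class ToyDiscriminator(nn.Module):
    """Placeholder standing in for the real Discriminator; it defines a custom method."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(16, 1)

    def forward(self, x):
        return self.net(x)

    def calc_gradient_penalty(self, real, fake):
        # dummy body; only the calling convention matters for this illustration
        return (real - fake).pow(2).mean()

device = "cuda:0" if torch.cuda.is_available() else "cpu"
disc = ToyDiscriminator().to(device)
parallel = torch.cuda.device_count() > 1
if parallel:
    disc = nn.DataParallel(disc, device_ids=list(range(torch.cuda.device_count())))

real = torch.randn(8, 16, device=device)
fake = torch.randn(8, 16, device=device)
score = disc(fake)  # forward() works as usual through the wrapper
# custom methods are not forwarded by DataParallel, so call them via .module when wrapped
pen = (disc.module if parallel else disc).calc_gradient_penalty(real, fake)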

Original Model Code

class CTGANSynthesizerModel(MLSynthesizerModel, BatchedSynthesizer):
    """
    Modified from ``sdgx.models.components.sdv_ctgan.synthesizers.ctgan.CTGANSynthesizer``.
    A CTGANSynthesizer but provided :ref:`SynthesizerModel` interface with chunked fit.

    This is the core class of the CTGAN project, where the different components are
    orchestrated together. For more details about the process, please check the
    [Modeling Tabular data using Conditional GAN](https://arxiv.org/abs/1907.00503) paper.

    Args:
        embedding_dim (int):
            Size of the random sample passed to the Generator. Defaults to 128.
        generator_dim (tuple or list of ints):
            Size of the output samples for each one of the Residuals. A Residual Layer
            will be created for each one of the values provided. Defaults to (256, 256).
        discriminator_dim (tuple or list of ints):
            Size of the output samples for each one of the Discriminator Layers. A Linear
            Layer will be created for each one of the values provided. Defaults to (256, 256).
        generator_lr (float):
            Learning rate for the generator. Defaults to 2e-4.
        generator_decay (float):
            Generator weight decay for the Adam Optimizer. Defaults to 1e-6.
        discriminator_lr (float):
            Learning rate for the discriminator. Defaults to 2e-4.
        discriminator_decay (float):
            Discriminator weight decay for the Adam Optimizer. Defaults to 1e-6.
        batch_size (int):
            Number of data samples to process in each step.
        discriminator_steps (int):
            Number of discriminator updates to do for each generator update.
            From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper default is 5.
            Default used is 1 to match original CTGAN implementation.
        log_frequency (boolean):
            Whether to use log frequency of categorical levels in conditional sampling.
            Defaults to ``True``.
        epochs (int):
            Number of training epochs. Defaults to 300.
        pac (int):
            Number of samples to group together when applying the discriminator. Defaults to 10.
        device (str):
            Device to run the training on. Preferred to be 'cuda' for GPU if available.
    """

    MODEL_SAVE_NAME = "ctgan.pkl"

    def __init__(
        self,
        embedding_dim=128,
        generator_dim=(256, 256),
        discriminator_dim=(256, 256),
        generator_lr=2e-4,
        generator_decay=1e-6,
        discriminator_lr=2e-4,
        discriminator_decay=1e-6,
        batch_size=500,
        discriminator_steps=1,
        log_frequency=True,
        epochs=300,
        pac=10,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ):
        assert batch_size % 2 == 0

        BatchedSynthesizer.__init__(self, batch_size=batch_size)
        self._embedding_dim = embedding_dim
        self._generator_dim = generator_dim
        self._discriminator_dim = discriminator_dim

        self._generator_lr = generator_lr
        self._generator_decay = generator_decay
        self._discriminator_lr = discriminator_lr
        self._discriminator_decay = discriminator_decay

        self._discriminator_steps = discriminator_steps
        self._log_frequency = log_frequency
        self._epochs = epochs
        self.pac = pac
        self._device = torch.device(device)

        # Following components are initialized in `_pre_fit`
        self._transformer: Optional[DataTransformer] = None
        self._data_sampler: Optional[DataSampler] = None
        self._generator = None
        self._ndarry_loader: Optional[NDArrayLoader] = None
        self.data_dim: Optional[int] = None

    def fit(self, metadata: Metadata, dataloader: DataLoader, epochs=None, *args, **kwargs):
        # In the future, sdgx use `sdgx.data_processor.transformers.discrete` to handle discrete_columns
        # the original sdv transformer will be removed in version 0.3.0
        # This will be done in another PR.
        discrete_columns = list(metadata.get("discrete_columns"))
        if epochs is not None:
            self._epochs = epochs
        self._pre_fit(dataloader, discrete_columns, metadata)
        if self.fit_data_empty:
            logger.info("CTGAN fit finished because of empty df detected.")
            return
        logger.info("CTGAN prefit finished, start CTGAN training.")
        self._fit(len(self._ndarry_loader))
        logger.info("CTGAN training finished.")

    def _pre_fit(
        self, dataloader: DataLoader, discrete_columns: list[str] = None, metadata: Metadata = None
    ):
        if not discrete_columns:
            discrete_columns = []
        # self._validate_discrete_columns(dataloader.columns(), discrete_columns)
        discrete_columns = self._filter_discrete_columns(dataloader.columns(), discrete_columns)
        # if the df is empty, we don't need to do anything
        if self.fit_data_empty:
            return
        # Fit Transformer and DataSampler
        self._transformer = DataTransformer(metadata=metadata)
        logger.info("Fitting model's transformer...")
        self._transformer.fit(dataloader, discrete_columns)
        logger.info("Transforming data...")
        self._ndarry_loader = self._transformer.transform(dataloader)
        logger.info("Sampling data.")
        self._data_sampler = DataSampler(
            self._ndarry_loader, self._transformer.output_info_list, self._log_frequency
        )
        logger.info("Initialize Generator.")
        # Initialize Generator
        self.data_dim = self._transformer.output_dimensions
        self._generator = Generator(
            self._embedding_dim + self._data_sampler.dim_cond_vec(),
            self._generator_dim,
            self.data_dim,
        ).to(self._device)

    @random_state
    def _fit(self, data_size: int):
        """Fit the CTGAN Synthesizer models to the training data."""
        logger.info(f"Fit using data_size:{data_size}, data_dim: {self.data_dim}.")
        epochs = self._epochs
        # data_dim = self._transformer.output_dimensions
        discriminator = Discriminator(
            self.data_dim + self._data_sampler.dim_cond_vec(),
            self._discriminator_dim,
            pac=self.pac,
        ).to(self._device)

        optimizerG = optim.Adam(
            self._generator.parameters(),
            lr=self._generator_lr,
            betas=(0.5, 0.9),
            weight_decay=self._generator_decay,
        )

        optimizerD = optim.Adam(
            discriminator.parameters(),
            lr=self._discriminator_lr,
            betas=(0.5, 0.9),
            weight_decay=self._discriminator_decay,
        )

        mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device)
        std = mean + 1

        logger.info("Starting model training, epochs: {}".format(epochs))
        steps_per_epoch = max(data_size // self._batch_size, 1)
        for i in range(epochs):
            start_time = time.time()
            for id_ in tqdm.tqdm(range(steps_per_epoch), desc="Fitting batches", delay=3):
                for n in range(self._discriminator_steps):
                    fakez = torch.normal(mean=mean, std=std)

                    condvec = self._data_sampler.sample_condvec(self._batch_size)
                    if condvec is None:
                        c1, m1, col, opt = None, None, None, None
                        real = self._data_sampler.sample_data(self._batch_size, col, opt)
                    else:
                        c1, m1, col, opt = condvec
                        c1 = torch.from_numpy(c1).to(self._device)
                        m1 = torch.from_numpy(m1).to(self._device)
                        fakez = torch.cat([fakez, c1], dim=1)

                        perm = np.arange(self._batch_size)
                        np.random.shuffle(perm)
                        real = self._data_sampler.sample_data(
                            self._batch_size, col[perm], opt[perm]
                        )
                        c2 = c1[perm]

                    fake = self._generator(fakez)
                    fakeact = self._apply_activate(fake)

                    real = torch.from_numpy(real.astype("float32")).to(self._device)

                    if c1 is not None:
                        fake_cat = torch.cat([fakeact, c1], dim=1)
                        real_cat = torch.cat([real, c2], dim=1)
                    else:
                        real_cat = real
                        fake_cat = fakeact

                    y_fake = discriminator(fake_cat)
                    y_real = discriminator(real_cat)

                    pen = discriminator.calc_gradient_penalty(
                        real_cat, fake_cat, self._device, self.pac
                    )
                    loss_d = -(torch.mean(y_real) - torch.mean(y_fake))

                    optimizerD.zero_grad()
                    pen.backward(retain_graph=True)
                    loss_d.backward()
                    optimizerD.step()

                fakez = torch.normal(mean=mean, std=std)
                condvec = self._data_sampler.sample_condvec(self._batch_size)

                if condvec is None:
                    c1, m1, col, opt = None, None, None, None
                else:
                    c1, m1, col, opt = condvec
                    c1 = torch.from_numpy(c1).to(self._device)
                    m1 = torch.from_numpy(m1).to(self._device)
                    fakez = torch.cat([fakez, c1], dim=1)

                fake = self._generator(fakez)
                fakeact = self._apply_activate(fake)

                if c1 is not None:
                    y_fake = discriminator(torch.cat([fakeact, c1], dim=1))
                else:
                    y_fake = discriminator(fakeact)

                if condvec is None:
                    cross_entropy = 0
                else:
                    cross_entropy = self._cond_loss(fake, c1, m1)

                loss_g = -torch.mean(y_fake) + cross_entropy

                optimizerG.zero_grad()
                loss_g.backward()
                optimizerG.step()

            logger.info(
                f"Epoch {i+1}, Loss G: {loss_g.detach().cpu(): .4f},"  # noqa: T001
                f" Loss D: {loss_d.detach().cpu(): .4f},"
                f" Time: {time.time() - start_time: .4f}",
            )

    def sample(self, count: int, *args, **kwargs) -> pd.DataFrame:
        if self.fit_data_empty:
            return pd.DataFrame(index=range(count))
        return self._sample(count, *args, **kwargs)

    @random_state
    def _sample(self, n, condition_column=None, condition_value=None, drop_more=True):
        """Sample data similar to the training data.

        Choosing a condition_column and condition_value will increase the probability of the
        discrete condition_value happening in the condition_column.

        Args:
            n (int):
                Number of rows to sample.
            condition_column (string):
                Name of a discrete column.
            condition_value (string):
                Name of the category in the condition_column which we wish to increase the
                probability of happening.

        Returns:
            numpy.ndarray or pandas.DataFrame
        """
        if condition_column is not None and condition_value is not None:
            condition_info = self._transformer.convert_column_name_value_to_id(
                condition_column, condition_value
            )
            global_condition_vec = self._data_sampler.generate_cond_from_condition_column_info(
                condition_info, self._batch_size
            )
        else:
            global_condition_vec = None

        steps = math.ceil(n / self._batch_size)
        data = []
        for _ in tqdm.tqdm(range(steps), desc="Sampling batches", delay=3):
            mean = torch.zeros(self._batch_size, self._embedding_dim)
            std = mean + 1
            fakez = torch.normal(mean=mean, std=std).to(self._device)

            if global_condition_vec is not None:
                condvec = global_condition_vec.copy()
            else:
                condvec = self._data_sampler.sample_original_condvec(self._batch_size)

            if condvec is None:
                pass
            else:
                c1 = condvec
                c1 = torch.from_numpy(c1).to(self._device)
                fakez = torch.cat([fakez, c1], dim=1)

            fake = self._generator(fakez)
            fakeact = self._apply_activate(fake)
            data.append(fakeact.detach().cpu().numpy())

        data = np.concatenate(data, axis=0)
        logger.info("CTGAN Generated {} raw samples.".format(data.shape[0]))
        if drop_more:
            data = data[:n]

        return self._transformer.inverse_transform(data)

    def save(self, save_dir: str | Path):
        save_dir.mkdir(parents=True, exist_ok=True)
        return SDVBaseSynthesizer.save(self, save_dir / self.MODEL_SAVE_NAME)

    @classmethod
    def load(cls, save_dir: str | Path, device: str = None) -> "CTGANSynthesizerModel":
        return SDVBaseSynthesizer.load(save_dir / cls.MODEL_SAVE_NAME, device)

    @staticmethod
    def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
        """Deals with the instability of the gumbel_softmax for older versions of torch.

        For more details about the issue:
        https://drive.google.com/file/d/1AA5wPfZ1kquaRtVruCd6BiYZGcDeNxyP/view?usp=sharing

        Args:
            logits […, num_features]:
                Unnormalized log probabilities
            tau:
                Non-negative scalar temperature
            hard (bool):
                If True, the returned samples will be discretized as one-hot vectors,
                but will be differentiated as if it is the soft sample in autograd
            dim (int):
                A dimension along which softmax will be computed. Default: -1.

        Returns:
            Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.
        """
        if version.parse(torch.__version__) < version.parse("1.2.0"):
            for i in range(10):
                transformed = functional.gumbel_softmax(
                    logits, tau=tau, hard=hard, eps=eps, dim=dim
                )
                if not torch.isnan(transformed).any():
                    return transformed
            raise ValueError("gumbel_softmax returning NaN.")

        return functional.gumbel_softmax(logits, tau=tau, hard=hard, eps=eps, dim=dim)

    def _apply_activate(self, data):
        """Apply proper activation function to the output of the generator."""
        data_t = []
        st = 0
        for column_info in self._transformer.output_info_list:
            for span_info in column_info:
                if span_info.activation_fn == "tanh":
                    ed = st + span_info.dim
                    data_t.append(torch.tanh(data[:, st:ed]))
                    st = ed
                elif span_info.activation_fn == "softmax":
                    ed = st + span_info.dim
                    transformed = self._gumbel_softmax(data[:, st:ed], tau=0.2)
                    data_t.append(transformed)
                    st = ed
                elif span_info.activation_fn == "linear":  # for label encoder
                    ed = st + span_info.dim
                    transformed = data[:, st:ed].clone()
                    data_t.append(transformed)
                    st = ed
                else:
                    raise ValueError(f"Unexpected activation function {span_info.activation_fn}.")

        return torch.cat(data_t, dim=1)

    def _cond_loss(self, data, c, m):
        """Compute the cross entropy loss on the fixed discrete column."""
        loss = []
        st = 0
        st_c = 0
        for column_info in self._transformer.output_info_list:
            for span_info in column_info:
                if len(column_info) != 1 or span_info.activation_fn != "softmax":
                    # not discrete column
                    st += span_info.dim
                else:
                    ed = st + span_info.dim
                    ed_c = st_c + span_info.dim
                    tmp = functional.cross_entropy(
                        data[:, st:ed],
                        torch.argmax(c[:, st_c:ed_c], dim=1),
                        reduction="none",
                    )
                    loss.append(tmp)
                    st = ed
                    st_c = ed_c

        loss = torch.stack(loss, dim=1)  # noqa: PD013

        return (loss * m).sum() / data.size()[0]

    def _filter_discrete_columns(self, train_data: List[str], discrete_columns: List[str]):
        """
        We filter PII Column here, which PII would only be discrete for now.
        As PII would be generating from PII Generator which not synthetic from model.

        Besides we need to figure it out when to stop model fitting:
        The original data consists entirely of discrete column data,
        and all of this discrete column data is PII.

        For `train_data`, there are three possibilities for the columns type.
        - train_data = valid_discrete + valid_continue
        - train_data = valid_continue
        - train_data = valid_discrete

        For `discrete_columns`, discrete_columns = invalid_discrete(PII) + valid_discrete

        Thus, valid_discrete = discrete_columns - invalid_discrete
                             = discrete_columns - Set.intersection(train_data, discrete_columns)

        Thus, original_data_is_all_PII: discrete_columns is not empty & train_data is empty
        """
        # Discrete_columns is empty - simple an empty list, but we need to continue fitting continue columns
        if len(discrete_columns) == 0:
            return discrete_columns
        # Discrete_columns is not empty - check if train_data is empty for stop model fitting
        if len(train_data) == 0:
            self.fit_data_empty = True
            return discrete_columns
        # Filter valid discrete columns
        invalid_columns = set(discrete_columns) - set(train_data)
        return set(discrete_columns) - set(invalid_columns)

    def _validate_discrete_columns(self, train_data, discrete_columns):
        """Check whether ``discrete_columns`` exists in ``train_data``.

        Args:
            train_data (numpy.ndarray or pandas.DataFrame or list):
                Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
            discrete_columns (list-like):
                List of discrete columns to be used to generate the Conditional
                Vector. If ``train_data`` is a Numpy array, this list should
                contain the integer indices of the columns. Otherwise, if it is
                a ``pandas.DataFrame``, this list should contain the column names.
        """
        if isinstance(train_data, pd.DataFrame):
            invalid_columns = set(discrete_columns) - set(train_data.columns)
        elif isinstance(train_data, np.ndarray):
            invalid_columns = []
            for column in discrete_columns:
                if column < 0 or column >= train_data.shape[1]:
                    invalid_columns.append(column)
        elif isinstance(train_data, list):
            invalid_columns = set(discrete_columns) - set(train_data)
        else:
            raise TypeError("``train_data`` should be either pd.DataFrame or np.array.")

        if invalid_columns:
            raise ValueError(f"Invalid columns found: {invalid_columns}")

    def set_device(self, device):
        """Set the `device` to be used ('GPU' or 'CPU)."""
        self._device = device
        if self._generator is not None:
            self._generator.to(self._device)

Adjusted Model Code

class DPParallelCTGANSynthesizerModel(CTGANSynthesizerModel):
    """
    Parallel CTGAN model that uses DataParallel to leverage multiple GPUs in a single process.

    New argument:
        gpu_ids (list): list of GPU IDs to use, e.g. [0, 1, 2, 3]
    """

    def __init__(
        self,
        embedding_dim=128,
        generator_dim=(256, 256),
        discriminator_dim=(256, 256),
        generator_lr=2e-4,
        generator_decay=1e-6,
        discriminator_lr=2e-4,
        discriminator_decay=1e-6,
        batch_size=500,
        discriminator_steps=1,
        log_frequency=True,
        epochs=300,
        pac=10,
        device="cuda" if torch.cuda.is_available() else "cpu",
        # new parallelism argument
        gpu_ids=None,
    ):
        # call the parent constructor
        super().__init__(
            embedding_dim=embedding_dim,
            generator_dim=generator_dim,
            discriminator_dim=discriminator_dim,
            generator_lr=generator_lr,
            generator_decay=generator_decay,
            discriminator_lr=discriminator_lr,
            discriminator_decay=discriminator_decay,
            batch_size=batch_size,
            discriminator_steps=discriminator_steps,
            log_frequency=log_frequency,
            epochs=epochs,
            pac=pac,
            device=device,
        )
        self.gpu_ids = gpu_ids if gpu_ids is not None else list(range(torch.cuda.device_count()))
        self.is_parallel = len(self.gpu_ids) > 1
        # the discriminator, which the original code created on the fly inside _fit,
        # is promoted to a persistent member attribute
        self._discriminator = None
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def _pre_fit(self, dataloader: DataLoader, discrete_columns: list[str] = None, metadata=None):
        """
        Overridden pre-fit:
        - transformer and data-sampler initialization is unchanged;
        - the generator and discriminator are created up front as persistent members
          and wrapped as DataParallel modules.
        """
        if not discrete_columns:
            discrete_columns = []
        discrete_columns = self._filter_discrete_columns(dataloader.columns(), discrete_columns)
        if self.fit_data_empty:
            return
        # initialize the transformer and transform the data
        self._transformer = DataTransformer(metadata=metadata)
        logger.info("DP-parallel CTGAN: Fitting model's transformer...")
        self._transformer.fit(dataloader, discrete_columns)
        logger.info("DP-parallel CTGAN: Transforming data...")
        self._ndarry_loader = self._transformer.transform(dataloader)
        logger.info("DP-parallel CTGAN: Sampling data.")
        self._data_sampler = DataSampler(
            self._ndarry_loader, self._transformer.output_info_list, self._log_frequency
        )
        logger.info("DP-parallel CTGAN: Initialize Generator and Discriminator.")
        self.data_dim = self._transformer.output_dimensions
        # create the generator; its input includes the sampler's conditional-vector dimension
        self._generator = Generator(
            self._embedding_dim + self._data_sampler.dim_cond_vec(),
            self._generator_dim,
            self.data_dim,
        ).to(self._device)
        # create the persistent discriminator; its input likewise adds the conditional dimension
        self._discriminator = Discriminator(
            self.data_dim + self._data_sampler.dim_cond_vec(),
            self._discriminator_dim,
            pac=self.pac,
        ).to(self._device)
        # if more than one GPU is available, wrap both models in DataParallel
        if self.is_parallel:
            logger.info(f"Training with DataParallel on GPUs {self.gpu_ids}")
            self._generator = nn.DataParallel(self._generator, device_ids=self.gpu_ids)
            self._discriminator = nn.DataParallel(self._discriminator, device_ids=self.gpu_ids)

    @random_state
    def _fit(self, data_size: int):
        """
        Modified training loop that uses the persistent generator and discriminator
        instead of re-creating them inside _fit.
        """
        logger.info(f"DP-parallel CTGAN: Fit using data_size:{data_size}, data_dim: {self.data_dim}.")
        epochs = self._epochs
        optimizerG = optim.Adam(
            self._generator.parameters(),
            lr=self._generator_lr,
            betas=(0.5, 0.9),
            weight_decay=self._generator_decay,
        )
        optimizerD = optim.Adam(
            self._discriminator.parameters(),
            lr=self._discriminator_lr,
            betas=(0.5, 0.9),
            weight_decay=self._discriminator_decay,
        )
        mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device)
        std = mean + 1
        logger.info("DP-parallel CTGAN: Starting model training, epochs: {}".format(epochs))
        steps_per_epoch = max(data_size // self._batch_size, 1)
        for i in range(epochs):
            start_time = time.time()
            for id_ in tqdm.tqdm(range(steps_per_epoch), desc="Fitting batches", delay=3):
                # update the discriminator (possibly several times per generator step)
                for n in range(self._discriminator_steps):
                    fakez = torch.normal(mean=mean, std=std)
                    condvec = self._data_sampler.sample_condvec(self._batch_size)
                    if condvec is None:
                        c1, m1, col, opt = None, None, None, None
                        real = self._data_sampler.sample_data(self._batch_size, col, opt)
                    else:
                        c1, m1, col, opt = condvec
                        c1 = torch.from_numpy(c1).to(self._device)
                        m1 = torch.from_numpy(m1).to(self._device)
                        fakez = torch.cat([fakez, c1], dim=1)
                        perm = np.arange(self._batch_size)
                        np.random.shuffle(perm)
                        real = self._data_sampler.sample_data(
                            self._batch_size, np.array(col)[perm], np.array(opt)[perm]
                        )
                        c2 = c1[perm]
                    fake = self._generator(fakez)
                    fakeact = self._apply_activate(fake)
                    real = torch.from_numpy(real.astype("float32")).to(self._device)
                    if c1 is not None:
                        fake_cat = torch.cat([fakeact, c1], dim=1)
                        real_cat = torch.cat([real, c2], dim=1)
                    else:
                        fake_cat = fakeact
                        real_cat = real
                    y_fake = self._discriminator(fake_cat)
                    y_real = self._discriminator(real_cat)
                    # compute the gradient penalty; custom methods live on .module when wrapped in DataParallel
                    if self.is_parallel:
                        pen = self._discriminator.module.calc_gradient_penalty(
                            real_cat, fake_cat, self._device, self.pac
                        )
                    else:
                        pen = self._discriminator.calc_gradient_penalty(
                            real_cat, fake_cat, self._device, self.pac
                        )
                    loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
                    optimizerD.zero_grad()
                    pen.backward(retain_graph=True)
                    loss_d.backward()
                    optimizerD.step()

                # update the generator
                fakez = torch.normal(mean=mean, std=std)
                condvec = self._data_sampler.sample_condvec(self._batch_size)
                if condvec is None:
                    c1, m1, col, opt = None, None, None, None
                else:
                    c1, m1, col, opt = condvec
                    c1 = torch.from_numpy(c1).to(self._device)
                    m1 = torch.from_numpy(m1).to(self._device)
                    fakez = torch.cat([fakez, c1], dim=1)
                fake = self._generator(fakez)
                fakeact = self._apply_activate(fake)
                if c1 is not None:
                    y_fake = self._discriminator(torch.cat([fakeact, c1], dim=1))
                else:
                    y_fake = self._discriminator(fakeact)
                cross_entropy = 0 if condvec is None else self._cond_loss(fake, c1, m1)
                loss_g = -torch.mean(y_fake) + cross_entropy
                optimizerG.zero_grad()
                loss_g.backward()
                optimizerG.step()

            logger.info(
                f"Epoch {i+1}, Loss G: {loss_g.detach().cpu(): .4f},"
                f" Loss D: {loss_d.detach().cpu(): .4f},"
                f" Time: {time.time() - start_time: .4f}"
            )

    def save(self, save_dir: str):
        """Save the model."""
        os.makedirs(save_dir, exist_ok=True)
        # for a parallel model, save the underlying module rather than the DataParallel wrapper
        if self.is_parallel:
            generator = self._generator.module
            discriminator = self._discriminator.module
        else:
            generator = self._generator
            discriminator = self._discriminator
        # temporarily swap in the unwrapped models for saving
        temp_generator = self._generator
        temp_discriminator = self._discriminator
        self._generator = generator
        self._discriminator = discriminator
        # call the parent class's save method
        result = super().save(save_dir)
        # restore the wrapped models
        self._generator = temp_generator
        self._discriminator = temp_discriminator
        return result

    @classmethod
    def load(cls, save_dir: str, device: str = None, gpu_ids=None):
        """Load the model from the save directory."""
        model = super().load(save_dir, device=device)
        # build a parallel version from the loaded model
        parallel_model = cls(gpu_ids=gpu_ids)
        # further parameter updates after loading could be added here
        parallel_model.__dict__.update(model.__dict__)
        # re-wrap in DataParallel if needed
        if parallel_model.is_parallel:
            if hasattr(parallel_model, '_generator') and parallel_model._generator is not None:
                parallel_model._generator = nn.DataParallel(
                    parallel_model._generator, device_ids=parallel_model.gpu_ids
                )
            if hasattr(parallel_model, '_discriminator') and parallel_model._discriminator is not None:
                parallel_model._discriminator = nn.DataParallel(
                    parallel_model._discriminator, device_ids=parallel_model.gpu_ids
                )
        return parallel_model
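For completeness, here is a hedged sketch of how the parallel model might be driven end to end; metadata and dataloader are assumed to be sdgx Metadata / DataLoader objects prepared elsewhere, and the numbers are purely illustrative:

# Assumed: `metadata` and `dataloader` are sdgx Metadata / DataLoader objects built beforehand.
model = DPParallelCTGANSynthesizerModel(
    batch_size=4000,       # DataParallel splits each batch across the listed GPUs
    epochs=100,
    gpu_ids=[0, 1, 2, 3],  # must be GPUs actually exposed to the container via --gpus
)
model.fit(metadata, dataloader)
samples = model.sample(10_000)  # pandas DataFrame of synthetic rows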
If you have any questions, please contact me.