Loading and Saving PyTorch Models Without Knowing the Architecture in Advance

PyTorch is a great tool for deep learning research. However, when running large-scale experiments with various architectures, I always come across the same problem: How can I run the same experiments, evaluations or visualizations on models without knowing their architecture in advance? In this article, I want to present a simple approach that allows loading models without having to initialize the right architecture beforehand. The code of this article is available on GitHub.


There is one aspect of deep learning frameworks such as PyTorch or TensorFlow that I find very cumbersome for researchers: Usually, only the model parameters are written to file, but not the architecture details. This means that, before loading a model from file, the architecture needs to be initialized. As a result, I have to touch every training, evaluation or visualization script whenever I add a new architecture or merely change the hyper-parameters of an architecture.
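To make the problem concrete, here is a minimal sketch of the usual workflow. TinyNet and its hyper-parameters are hypothetical and only serve as illustration, and an in-memory buffer stands in for the model file:

```python
import io

import torch


# Hypothetical toy architecture, only used for illustration.
class TinyNet(torch.nn.Module):
    def __init__(self, N_class, hidden):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, hidden)
        self.fc2 = torch.nn.Linear(hidden, N_class)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = TinyNet(N_class=10, hidden=16)

# Only the parameters are written; the buffer stands in for a file on disk.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# To load, the script has to know the class and its exact hyper-parameters.
buffer.seek(0)
restored = TinyNet(N_class=10, hidden=16)
restored.load_state_dict(torch.load(buffer))

# With different hyper-parameters, loading fails because the shapes do not match.
buffer.seek(0)
try:
    TinyNet(N_class=10, hidden=32).load_state_dict(torch.load(buffer))
    mismatch = False
except RuntimeError:
    mismatch = True
```

Every script that loads this file needs the TinyNet(N_class=10, hidden=16) line, which is exactly the coupling I want to get rid of.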

In this article, I want to present my approach to this problem. In particular, I want to be able to store and load any model, irrespective of its architecture and without knowing the architecture in advance. This makes it possible to implement training, evaluation or visualization independently of network architectures, at least to a very large degree. The approach is based on storing the architecture alongside the model parameters. However, just serializing the whole model (architecture and parameters) will not work. Instead, all relevant hyper-parameters of the architecture are saved with the parameters, so the architecture can be initialized automatically when loading the parameters. To avoid having to take care of this for each architecture individually, a "base architecture" enforces the required structure. Here, I focus on basic classification networks.

This is part of a series of articles; the previous articles are:

Later articles will cover adversarial example generation and adversarial training.


The Base Architecture

The base model, called Classifier in Listing 1, takes care of two things: First, the hyper-parameters that are common to all architectures, as well as keyword arguments, are stored as attributes. Second, all layers (considering feed-forward architectures only) are managed centrally in __layers. The second part is not strictly needed in the following, but it allows a common forward method across all architectures:

Listing 1: The base classifier stores all hyper-parameters, including keyword arguments, as attributes. These attributes are later stored alongside the model's parameters, which allows the architecture to be initialized automatically.

import torch

class Classifier(torch.nn.Module):
    """Simple classifier."""

    # 1. The constructor takes the common hyper-parameters that are relevant to any classification network: the number of classes and the input resolution.
    # Any additional hyper-parameters are either keyword arguments or defined by derived classes.
    def __init__(self, N_class, resolution, **kwargs):
        super(Classifier, self).__init__()

        assert N_class > 0, 'positive N_class expected'
        assert len(resolution) <= 3

        # 1.1 The idea is that all "public" attributes, i.e., those not starting with _ or __ will be
        # stored alongside the model's parameters and used to initialize the architecture when loading.
        self.N_class = int(N_class)
        self.resolution = list(resolution)
        self.kwargs = kwargs

        # 1.2 This list will hold all the layers, we assume purely feed-forward models here.
        # Note that __ attributes are private, which is important for the State (later) to work properly.
        self.__layers = []

        # process any additional hyper-parameters that are common to all architectures, e.g., from kwargs.

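    # 2. Function to add a layer to self.__layers.
    # Note that managing the involved layers centrally is not required for this approach to work.
    def append_layer(self, name, layer):
        # Register the layer as a sub-module and remember its name for the forward pass.
        setattr(self, name, layer)
        self.__layers.append(name)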

    # 3. Forward function iterates over all layers.
    # The main goal of managing the layers centrally is to have a central forward function in order to potentially add
    # more functionality without having to touch individual architectures.
    def forward(self, image):
        output = image
        for name in self.__layers:
            output = getattr(self, name)(output)
        return output
  1. The constructor takes common hyper-parameters as arguments: the number of classes and the input resolution (assuming images). These, alongside keyword arguments, are stored as attributes. Note that these attributes can be assumed "public" as they do not have any underscore _ as prefix. The constructor also initializes a list of layers, __layers, that manages all layers of a feed-forward network centrally.
  2. The append_layer method is used to add new layers. Any child class, i.e., the actual architectures, will use append_layer to set up the network. Again, note that this is not necessary for automatic loading using State in Listing 3 to work.
  3. The advantage of a central forward method is that additional functionality (e.g., saving intermediate activations, adding pre-processing) can be added across all architectures.

The key point of Listing 1 is that all hyper-parameters required for a specific architecture are stored as attributes and that the setup of all required layers is supposed to happen within the constructor. As a result, saving the constructor arguments alongside the model's parameters makes it possible to initialize the architecture and load the parameters automatically. Note that the attribute names match the names of the constructor arguments. The following section shows an example network; it can be skipped in favor of the implementation for loading and saving.
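As a small, self-contained sketch of this idea, the following repeats a stripped-down Classifier and derives a toy architecture from it. TinyMLP and its units hyper-parameter are hypothetical and not part of the repository; the hyper-parameters are picked by name here to keep the example short:

```python
import torch


# Stripped-down version of the base class from Listing 1.
class Classifier(torch.nn.Module):
    def __init__(self, N_class, resolution, **kwargs):
        super().__init__()
        self.N_class = int(N_class)
        self.resolution = list(resolution)
        self.kwargs = kwargs
        self.__layers = []

    def append_layer(self, name, layer):
        setattr(self, name, layer)
        self.__layers.append(name)

    def forward(self, image):
        output = image
        for name in self.__layers:
            output = getattr(self, name)(output)
        return output


# Hypothetical toy architecture: all setup happens in the constructor.
class TinyMLP(Classifier):
    def __init__(self, N_class, resolution=(1, 4, 4), units=8, **kwargs):
        super().__init__(N_class, resolution, **kwargs)
        self.units = units  # attribute name matches the constructor argument
        inputs = resolution[0] * resolution[1] * resolution[2]
        self.append_layer('view', torch.nn.Flatten())
        self.append_layer('fc1', torch.nn.Linear(inputs, units))
        self.append_layer('relu1', torch.nn.ReLU())
        self.append_layer('logits', torch.nn.Linear(units, N_class))


model = TinyMLP(10, resolution=(1, 4, 4), units=8)

# The public attributes are sufficient to rebuild the same architecture.
hyper_parameters = {key: getattr(model, key) for key in ('N_class', 'resolution', 'units')}
rebuilt = TinyMLP(**{**hyper_parameters, **model.kwargs})
rebuilt.load_state_dict(model.state_dict())

output = rebuilt(torch.zeros(2, 1, 4, 4))
```

Because the attribute names equal the argument names, the dictionary of attributes can be passed straight back to the constructor.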

SimpleNet Example

Listing 2 shows an example for the SimpleNet [] architecture. It is less well known than, e.g., ResNets [], but it illustrates the idea in one short file as no shortcuts (i.e., residual connections) are needed:

The repository on GitHub includes implementations of ResNets [] and Wide ResNets [] in addition to SimpleNet.

Listing 2: SimpleNet implemented on top of the base Classifier from Listing 1; several blocks are elided.

class Flatten(torch.nn.Module):
    # ... flattens the input

def get_normalization2d(normalization, planes):
    # ... gets a BN or GN normalization layer

class SimpleNet(Classifier):
    def __init__(self, N_class, resolution=(1, 32, 32), activation='relu', dropout=False, normalization='bn', channels=64, **kwargs):
        assert resolution[1] >= 16 and resolution[2] >= 16
        assert activation in ['relu', 'sigmoid', 'tanh', 'leaky_relu', 'leaky_tanh', 'softsign']

        # 1. Call the base class to setup layers and attributes.
        super(SimpleNet, self).__init__(N_class, resolution, **kwargs)

        # 2. Save any additional hyper-parameters as attributes.
        self.activation = activation
        self.normalization = normalization
        self.dropout = dropout
        self.channels = channels

        # 3. Set up multiple blocks of convolution, normalization, activation and potentially pooling.
        # block 1
        block = 1
        in_channels = resolution[0]
        out_channels = self.channels
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=self.include_bias)
        torch.nn.init.kaiming_normal_(conv.weight, nonlinearity=activation)
        if self.include_bias:
            torch.nn.init.constant_(conv.bias, 0)
        self.append_layer('conv%d' % block, conv)
        self.append_layer('%s%d' % (self.normalization, block), get_normalization2d(self.normalization, out_channels))
        relu = activation_layer()
        self.append_layer('relu%d' % block, relu)

        # More blocks like block 1, including a couple that increase/decrease number of channels
        # and include pooling layers.
        # ... 

        # block 13
        block += 1
        in_channels = 2*self.channels
        out_channels = 2*self.channels
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=self.include_bias)
        torch.nn.init.kaiming_normal_(conv.weight, nonlinearity=activation)
        if self.include_bias:
            torch.nn.init.constant_(conv.bias, 0)
        self.append_layer('conv%d' % block, conv)
        self.append_layer('%s%d' % (self.normalization, block), get_normalization2d(self.normalization, out_channels))
        relu = activation_layer()
        self.append_layer('relu%d' % block, relu)

        pool = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)
        self.append_layer('pool%d' % block, pool)

        if self.dropout:
            drop = torch.nn.Dropout2d(p=0.1)
            self.append_layer('drop%d' % block, drop)

        view = Flatten()
        self.append_layer('view%d' % block, view)

        logits = torch.nn.Linear(out_channels, self.N_class, bias=self.include_bias)
        torch.nn.init.kaiming_normal_(logits.weight, nonlinearity=activation)
        if self.include_bias:
            torch.nn.init.constant_(logits.bias, 0)
        self.append_layer('logits', logits)

        # 4. Note that no forward method is required and all setup happens in the constructor.
  1. First, the base Classifier constructor is called to set up the required attributes.
  2. Any additional hyper-parameters are stored as attributes; the attributes' names match the constructor arguments' names.
  3. SimpleNet consists of several blocks, each made up of a convolutional layer, a batch or group normalization layer, and an activation layer. Some blocks are followed by max pooling layers. These layers are created as usual in PyTorch but added using the self.append_layer method.
  4. Note that no dedicated forward method is necessary.
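Listing 2 leaves out the bodies of Flatten and get_normalization2d. The following is only a sketch of what such helpers could look like, not the repository's implementation. The 'gn' identifier and the groups argument are assumptions on my side; only 'bn' appears in the listings:

```python
import torch


class Flatten(torch.nn.Module):
    """Flattens all dimensions except the batch dimension."""

    def forward(self, input):
        return input.view(input.size(0), -1)


def get_normalization2d(normalization, planes, groups=4):
    """Returns a batch or group normalization layer for `planes` channels.

    The 'gn' identifier and the default number of groups are assumptions.
    """
    if normalization == 'bn':
        return torch.nn.BatchNorm2d(planes)
    if normalization == 'gn':
        return torch.nn.GroupNorm(groups, planes)
    raise ValueError('unknown normalization %s' % str(normalization))


flat = Flatten()(torch.zeros(2, 3, 4, 4))
bn = get_normalization2d('bn', 8)
gn = get_normalization2d('gn', 8)
```

Both helpers return regular torch.nn.Module instances, so they can be passed to append_layer like any other layer.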

Note that the set of attributes (without __ or _ prefix) matches the constructor's parameters and fully defines the corresponding architecture. This means that we only need to store the number of classes, the resolution and SimpleNet's hyper-parameters (activation, dropout, normalization, and channels) to initialize the architecture.
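A minimal sketch of how these public attributes can be collected automatically, using the same kind of filter as the save function in Listing 3. TinyNet and public_attributes are hypothetical names used only for this illustration:

```python
import torch


# Hypothetical toy model with public hyper-parameters and one private attribute.
class TinyNet(torch.nn.Module):
    def __init__(self, N_class, units=8, **kwargs):
        super().__init__()
        self.N_class = N_class
        self.units = units
        self.kwargs = kwargs
        self._cache = None  # private, must not be treated as a hyper-parameter
        self.logits = torch.nn.Linear(units, N_class)


def public_attributes(model):
    """Collects all non-callable attributes that do not start with an underscore."""
    return {key: getattr(model, key) for key in dir(model)
            if not key.startswith('_') and key not in ('kwargs', 'T_destination')
            and not callable(getattr(model, key))}


model = TinyNet(10, units=8)
arguments = public_attributes(model)

# Layers are callable and therefore skipped, private attributes are skipped by name.
rebuilt = TinyNet(**{**arguments, **model.kwargs})
```

Flags that torch.nn.Module defines itself, such as training, can pass the same filter. They then end up in **kwargs when the architecture is instantiated again, which is one more reason for the constructors to accept keyword arguments.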

Loading and Saving

Loading and saving models is encapsulated in the following State class with the static functions load and checkpoint. Besides saving the model's hyper-parameters and parameters, it can also store the optimizer and scheduler used for training so that training can easily be resumed:

Listing 3: The State class makes it easy to save and load models, potentially including an optimizer and learning rate scheduler. The key advantage is that it also saves the architecture's hyper-parameters, so the model can be loaded without instantiating the architecture first.

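import os

import torch
from packaging import version  # assumed source of version.parse

# utils is the repository's helper module providing get_class.

class State:
    """State of a model, including optional epoch and optimizer."""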

    # 1. A state consists of a model, optimizer, scheduler and an epoch.
    # This allows to resume training given the used optimizer, its internal state, and the scheduler.
    def __init__(self, model, optimizer=None, scheduler=None, epoch=None):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.epoch = epoch

    # 2. Saves the current state:
    # - The model class name is saved
    # - Parameters, optimizer, scheduler and epoch are added.
    # - All attributes of the model class that do not start with _ or __ are saved (these are assumed to be the model's hyper-parameters).
    # - All keyword arguments of the model are saved as well.
    def save(self, filepath):
        model = self.model
        if not isinstance(model, dict):
            model = model.state_dict()
        model_class = self.model.__class__.__name__

        optimizer = self.optimizer
        if not isinstance(optimizer, dict) and optimizer is not None:
            optimizer = optimizer.state_dict()

        scheduler = self.scheduler
        if not isinstance(scheduler, dict) and scheduler is not None:
            scheduler = scheduler.state_dict()

        epoch = self.epoch
        assert utils.get_class('models', model_class) is not False
        arguments = dict((key, getattr(self.model, key)) for key in dir(self.model)
                         if not callable(getattr(self.model, key)) and not key.startswith('_') and not key == 'kwargs' and not key == 'T_destination')
        kwargs = getattr(self.model, 'kwargs', None)

        data = {'model': model, 'model_class': model_class,
                'optimizer': optimizer, 'scheduler': scheduler,
                'epoch': epoch, 'arguments': arguments, 'kwargs': kwargs}
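        # Newer PyTorch versions are told to write the legacy (non-zipfile) format with pickle protocol 2.
        if version.parse('1.6.0') < version.parse(torch.__version__):
            torch.save(data, filepath, pickle_protocol=2, _use_new_zipfile_serialization=False)
        else:
            torch.save(data, filepath)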

    # 3. Loading a model:
    # - The architecture is instantiated using the saved class name and the saved arguments and keyword arguments.
    # - A new state is created using the model (after loading parameters), optimizer, scheduler and epoch.
    # Note that optimizer, scheduler and epoch may be None.
    @staticmethod
    def load(filepath):
        assert os.path.exists(filepath), 'file %s not found' % str(filepath)

        # https://discuss.pytorch.org/t/gpu-memory-usage-increases-by-90-after-torch-load/9213/3
        checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)

        model_class = utils.get_class('models', checkpoint['model_class'])
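        if 'kwargs' in checkpoint:
            arguments = {**checkpoint['arguments'], **checkpoint['kwargs']}
        else:
            arguments = {**checkpoint['arguments']}
        model = model_class(**arguments)
        # Restore the saved parameters in the freshly instantiated architecture.
        model.load_state_dict(checkpoint['model'])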

        state = State(model, checkpoint['optimizer'], checkpoint['scheduler'], checkpoint['epoch'])

        del checkpoint

        return state

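    # Convenience function that creates a state and writes it to file in one call.
    @staticmethod
    def checkpoint(filepath, model, optimizer=None, scheduler=None, epoch=None):
        state = State(model, optimizer, scheduler, epoch)
        state.save(filepath)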

The important parts of the State class are the load and save functions (i.e., points 2 and 3):

  1. A state consists of a model and, optionally, an optimizer, a scheduler and an epoch. The latter three ensure that training can be resumed easily.
  2. The save function writes the state to file. This is done as follows: First, the model class is saved. This makes it easy to instantiate a new model when loading the file. Second, the model's arguments are stored; these are assumed to correspond to the model's attributes (not starting with _). This is why the hyper-parameters in Listing 2 are added as attributes. Third, the same is done for any potential keyword arguments. And fourth, the model's parameters as well as optimizer, scheduler and epoch are saved (the latter three may be None).
  3. Loading a state works as follows: After reading the model file, a new model is instantiated using the read model class, its arguments and keyword arguments. Then, the corresponding parameters are read.
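Listing 3 relies on utils.get_class, which is not shown in this article. A minimal sketch of such a lookup, assuming it resolves a class by module and class name and returns False on failure (the latter is suggested by the assertion in save), could look as follows. This is not necessarily how the repository implements it:

```python
import importlib


def get_class(module_name, class_name):
    """Returns the class `class_name` from module `module_name`, or False if it cannot be found."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    return getattr(module, class_name, False)


# Standard library classes are used here in place of the models module.
ordered_dict_class = get_class('collections', 'OrderedDict')
missing_class = get_class('collections', 'NoSuchClass')
missing_module = get_class('no_such_module_xyz', 'NoSuchClass')
```

With a lookup like this, the string stored under model_class is enough to get back the constructor, as long as the class is reachable from the models module.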

Put together, the State class can be used as follows:

Listing 4: Example of using State from Listing 3 to save and load a SimpleNet.

import common.state
import models

N_class = 10
resolution = [3, 32, 32]
dropout = False
normalization = 'bn'

model = models.SimpleNet(N_class, resolution, dropout=dropout, channels=64, normalization=normalization)

common.state.State.checkpoint('model.pth.tar', model)
print('wrote model.pth.tar')

# note that we do not need to initialize the model again!
state = common.state.State.load('model.pth.tar')
print('read model.pth.tar')

model = state.model
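Note that save in Listing 3 stores the state dictionaries of optimizer and scheduler, and load passes them on unchanged. To resume training, a new optimizer and scheduler therefore have to be created and fed with these dictionaries. The following sketch shows this with plain PyTorch and an in-memory buffer instead of the State class; the linear model, SGD and StepLR settings are arbitrary choices for illustration:

```python
import io

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

# One dummy training step and one scheduler step, i.e., one finished epoch.
loss = model(torch.ones(3, 4)).sum()
loss.backward()
optimizer.step()
scheduler.step()

# Save state dictionaries, mirroring what State.save writes to file.
buffer = io.BytesIO()
torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(), 'epoch': 1}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer)

# Resume: rebuild the objects, then restore their internal state.
new_model = torch.nn.Linear(4, 2)
new_model.load_state_dict(checkpoint['model'])
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
new_optimizer.load_state_dict(checkpoint['optimizer'])
new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.5)
new_scheduler.load_state_dict(checkpoint['scheduler'])
start_epoch = checkpoint['epoch'] + 1
```

With the State class, the three dictionaries would come from state.model.state_dict(), state.optimizer and state.scheduler instead of the checkpoint dictionary used here.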


In this article, I showed how a simple, custom State class can be used to load stored PyTorch models without knowing the architecture to be loaded in advance. This is achieved by having a common base classifier class and saving all arguments of the architecture (in addition to keyword arguments) as attributes of the model. Instead of saving only parameters, the State class in Listing 3 also saves the model class and its attributes. When loading the state file, the architecture can be instantiated from the class name, using the attributes as arguments.

Next in this article series, I will talk about computing adversarial examples, training adversarially robust models using adversarial or confidence-calibrated adversarial training [] and evaluating robustness appropriately.

  • [] HasanPour, S. H., Rouhani, M., Fayyaz, M., and Sabokrou, M. Lets keep it simple, using simple architectures to outperform deeper and more complex architectures. arXiv.org, abs/1608.06037, 2016.
  • [] He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In CVPR, 2016.
  • [] Zagoruyko, S. and Komodakis, N. Wide residual networks. In BMVC, 2016.
  • [] Stutz, D., Hein, M., and Schiele, B. Confidence-calibrated adversarial training: Generalizing to unseen attacks. In ICML, 2020.