Introduction
There is one aspect of deep learning frameworks such as PyTorch or TensorFlow that I find very cumbersome for researchers: usually, only the model parameters are written to file, but not the architecture details. This means that the architecture needs to be initialized before a model can be loaded from file. As a result, I have to touch every training, evaluation, or visualization script whenever I add a new architecture or merely change the hyper-parameters of an existing one.
In this article, I want to present my approach to this problem. In particular, I want to be able to store and load any model, irrespective of its architecture and without knowing the architecture in advance. This makes it possible to implement training, evaluation, or visualization independently of network architectures, at least to a very large degree. The approach is based on storing the architecture alongside the model parameters. However, just serializing the whole model (architecture and parameters) will not work. Instead, all relevant hyper-parameters of the architecture are saved with the parameters, so the architecture can be initialized automatically when loading the parameters. To avoid having to take care of this for each architecture individually, a "base architecture" enforces the required structure. Here, I focus on basic classification networks.
This is part of a series of articles.
Later articles will cover adversarial example generation and adversarial training.
The code for this article is available on GitHub.
The Base Architecture
The base model, called Classifier in Listing 1, takes care of two things: First, the hyper-parameters that are common to all architectures, as well as keyword arguments, are stored as attributes. Second, all layers (considering feed-forward architectures only) are managed centrally in __layers. The second part is not strictly needed for what follows, but it allows a common forward method across all architectures:
Listing 1: The base classifier stores all hyper-parameters, including keyword arguments, as attributes. These attributes are later stored alongside the model's parameters, which allows the architecture to be initialized automatically.
import torch


class Classifier(torch.nn.Module):
    """
    Simple classifier.
    """

    # 1. The constructor takes the common hyper-parameters that are relevant to any classification network:
    # the number of classes and the input resolution.
    # Any additional hyper-parameters are either keyword arguments or defined by derived classes.
    def __init__(self, N_class, resolution, **kwargs):
        super(Classifier, self).__init__()

        assert N_class > 0, 'positive N_class expected'
        assert len(resolution) <= 3

        # 1.1 The idea is that all "public" attributes, i.e., those not starting with _ or __, will be
        # stored alongside the model's parameters and used to initialize the architecture when loading.
        self.N_class = int(N_class)
        self.resolution = list(resolution)
        self.kwargs = kwargs

        # 1.2 This list will hold all the layers; we assume purely feed-forward models here.
        # Note that __ attributes are private, which is important for the State (later) to work properly.
        self.__layers = []

        # process any additional hyper-parameters that are common to all architectures, e.g., from kwargs.

    # 2. Function to add a layer to self.__layers.
    # Note that managing the involved layers centrally is not required for this approach to work.
    def append_layer(self, name, layer):
        setattr(self, name, layer)
        self.__layers.append(name)

    # 3. Forward function iterates over all layers.
    # The main goal of managing the layers centrally is to have a central forward function in order to
    # potentially add more functionality without having to touch individual architectures.
    def forward(self, image):
        output = image
        for name in self.__layers:
            output = getattr(self, name)(output)
        return output
- The constructor takes common hyper-parameters as arguments: the number of classes and the input resolution (assuming images). These, alongside keyword arguments, are stored as attributes. Note that these attributes can be assumed "public" as they do not have any underscore _ as prefix. The constructor also initializes a list of layers, __layers, that will manage all layers of a feed-forward network centrally.
- The append_layer method is used to add new layers. Any child class, i.e., the actual architectures, will use append_layer to set up the network. Again, note that this is not necessary for automatic loading using State in Listing 3 to work.
- The advantage of a central forward method is that additional functionality (e.g., saving intermediate activations, adding pre-processing) can be added across all architectures.
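The remark about __layers being private can be illustrated without PyTorch. The following is a minimal sketch, not the code from the repository: ToyClassifier and ToyNet are hypothetical stand-ins for Classifier and a derived architecture, and plain Python functions play the role of layers. It shows that Python's name mangling stores __layers under _ToyClassifier__layers, so a filter that skips attributes starting with an underscore will never pick up the layer list.

```python
class ToyClassifier:
    """Hypothetical, torch-free stand-in for the Classifier of Listing 1."""

    def __init__(self, N_class, resolution, **kwargs):
        self.N_class = int(N_class)
        self.resolution = list(resolution)
        self.kwargs = kwargs
        # Name mangling turns this into self._ToyClassifier__layers.
        self.__layers = []

    def append_layer(self, name, layer):
        setattr(self, name, layer)
        self.__layers.append(name)

    def forward(self, value):
        # Apply all registered layers in the order they were appended.
        output = value
        for name in self.__layers:
            output = getattr(self, name)(output)
        return output


class ToyNet(ToyClassifier):
    """Hypothetical architecture; plain functions play the role of layers."""

    def __init__(self, N_class, resolution=(1, 32, 32), scale=2, **kwargs):
        super().__init__(N_class, resolution, **kwargs)
        self.scale = scale
        self.append_layer('multiply', lambda x: x * self.scale)
        self.append_layer('shift', lambda x: x + 1)


net = ToyNet(10, scale=3)
# Attributes without a leading underscore; the mangled layer list is not among them.
public = [key for key in vars(net) if not key.startswith('_')]
print(public)
print(net.forward(2))
```

Note that the layers themselves (multiply and shift here) still appear among the public attributes; a filter on public attributes therefore also needs to skip callables, which is what Listing 3 does.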
The key point of Listing 1 is that all hyper-parameters required for a specific architecture are stored as attributes and that all required layers are set up within the constructor. As a result, saving the constructor arguments alongside the model's parameters makes it possible to initialize the architecture and load the parameters automatically. Note that the attribute names match the names of the constructor arguments. The following section shows an example network; it can be skipped in favor of the implementation for loading and saving.
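Because the approach depends on attribute names matching constructor argument names, a small check can catch mismatches early. The sketch below is an assumption on my side and not part of the article's code: missing_constructor_attributes is a hypothetical helper, and Base, Good and Bad are toy classes used only for illustration.

```python
import inspect


def missing_constructor_attributes(model):
    """Return constructor argument names that are not mirrored as attributes.

    Hypothetical helper: it walks the constructors of the whole class
    hierarchy and skips self, *args and **kwargs.
    """
    missing = set()
    for cls in type(model).__mro__:
        init = cls.__dict__.get('__init__')
        if init is None or not inspect.isfunction(init):
            continue  # e.g. object.__init__, which is not a Python function
        for name, parameter in inspect.signature(init).parameters.items():
            if name == 'self':
                continue
            if parameter.kind in (inspect.Parameter.VAR_POSITIONAL,
                                  inspect.Parameter.VAR_KEYWORD):
                continue
            if not hasattr(model, name):
                missing.add(name)
    return sorted(missing)


class Base:
    def __init__(self, N_class, resolution, **kwargs):
        self.N_class = N_class
        self.resolution = list(resolution)
        self.kwargs = kwargs


class Good(Base):
    def __init__(self, N_class, resolution=(1, 32, 32), channels=64, **kwargs):
        super().__init__(N_class, resolution, **kwargs)
        self.channels = channels  # attribute name matches the argument name


class Bad(Base):
    def __init__(self, N_class, resolution=(1, 32, 32), channels=64, **kwargs):
        super().__init__(N_class, resolution, **kwargs)
        self.width = channels  # attribute name differs from the argument name


print(missing_constructor_attributes(Good(10)))
print(missing_constructor_attributes(Bad(10)))
```

A model like Bad would be saved with a width entry instead of channels, so the stored hyper-parameter would no longer reach the intended constructor argument when loading.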
SimpleNet Example
Listing 2 shows an example for the SimpleNet [1] architecture. It is less well known than, e.g., ResNets [2], but it illustrates the idea in one short file because no shortcuts (i.e., residual connections) are needed.
The repository on GitHub includes implementations of ResNets [2] and Wide ResNets [3] in addition to SimpleNet.
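Listing 2: SimpleNet implemented on top of the base classifier from Listing 1; all hyper-parameters are stored as attributes.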
class Flatten(torch.nn.Module):
    # ... flattens the input


def get_normalization2d(normalization, planes):
    # ... gets a BN or GN normalization layer


class SimpleNet(Classifier):
    def __init__(self, N_class, resolution=(1, 32, 32), activation='relu', dropout=False, normalization='bn', channels=64, **kwargs):
        assert resolution[1] >= 16 and resolution[2] >= 16
        assert activation in ['relu', 'sigmoid', 'tanh', 'leaky_relu', 'leaky_tanh', 'softsign']

        # 1. Call the base class to set up layers and attributes.
        super(SimpleNet, self).__init__(N_class, resolution, **kwargs)

        # 2. Save any additional hyper-parameters as attributes.
        self.activation = activation
        self.normalization = normalization
        self.dropout = dropout
        self.channels = channels

        # Note: self.include_bias and activation_layer are used below but are not
        # defined in this excerpt; they are assumed to be set up elsewhere in the full code.

        # 3. Set up multiple blocks of convolution, normalization, activation and potentially pooling.
        # block 1
        block = 1
        in_channels = resolution[0]
        out_channels = self.channels
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=self.include_bias)
        torch.nn.init.kaiming_normal_(conv.weight, nonlinearity=activation)
        if self.include_bias:
            torch.nn.init.constant_(conv.bias, 0)
        self.append_layer('conv%d' % block, conv)
        self.append_layer('%s%d' % (self.normalization, block), get_normalization2d(self.normalization, out_channels))
        relu = activation_layer()
        self.append_layer('relu%d' % block, relu)

        # More blocks like block 1, including a couple that increase/decrease number of channels
        # and include pooling layers.
        # ...

        # block 13
        block += 1
        in_channels = 2*self.channels
        out_channels = 2*self.channels
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=self.include_bias)
        torch.nn.init.kaiming_normal_(conv.weight, nonlinearity=activation)
        if self.include_bias:
            torch.nn.init.constant_(conv.bias, 0)
        self.append_layer('conv%d' % block, conv)
        self.append_layer('%s%d' % (self.normalization, block), get_normalization2d(self.normalization, out_channels))
        relu = activation_layer()
        self.append_layer('relu%d' % block, relu)

        pool = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)
        self.append_layer('pool%d' % block, pool)

        if self.dropout:
            drop = torch.nn.Dropout2d(p=0.1)
            self.append_layer('drop%d' % block, drop)

        view = Flatten()
        self.append_layer('view%d' % block, view)

        logits = torch.nn.Linear(out_channels, self.N_class, bias=self.include_bias)
        torch.nn.init.kaiming_normal_(logits.weight, nonlinearity=activation)
        if self.include_bias:
            torch.nn.init.constant_(logits.bias, 0)
        self.append_layer('logits', logits)

    # 4. Note that no forward method is required and all setup happens in the constructor.
- First, the base Classifier constructor is called to set up the required attributes.
- Any additional hyper-parameters are stored as attributes; the attributes' names match the constructor arguments' names.
- SimpleNet consists of several blocks of convolutional layer, batch or group normalization layer, and ReLU layer. Some blocks are followed by max pooling layers. These layers are created as usual in PyTorch but added using the self.append_layer method.
- Note that no dedicated forward method is necessary.
Note that the set of attributes (without __ or _ prefix) is, at the same time, the set of constructor parameters and fully defines the corresponding architecture. This means that we only need to store the number of classes, the resolution, and SimpleNet's hyper-parameters (activation, dropout, normalization, and channels) to initialize the architecture.
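To make this concrete, here is a torch-free sketch of how the public attributes can be collected into a dictionary and fed back into the constructor. It mirrors the attribute filter used in Listing 3, but collect_arguments, ToyModule and ToySimpleNet are hypothetical names used only for illustration. ToyModule sets a public training flag because a real torch.nn.Module also exposes public attributes such as training; these end up in the dictionary as well.

```python
def collect_arguments(model):
    """Collect public, non-callable attributes (hypothetical helper)."""
    return {
        key: getattr(model, key)
        for key in dir(model)
        if not key.startswith('_')
        and not callable(getattr(model, key))
        and key != 'kwargs'
    }


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module with one public flag."""

    def __init__(self):
        self.training = True


class ToySimpleNet(ToyModule):
    def __init__(self, N_class, resolution=(1, 32, 32), activation='relu',
                 dropout=False, normalization='bn', channels=64, **kwargs):
        super().__init__()
        self.N_class = int(N_class)
        self.resolution = list(resolution)
        self.kwargs = kwargs
        self.activation = activation
        self.dropout = dropout
        self.normalization = normalization
        self.channels = channels
        self._cache = {}          # skipped: starts with an underscore
        self.conv1 = lambda x: x  # skipped: callable, like a layer


arguments = collect_arguments(ToySimpleNet(10, resolution=(3, 32, 32), channels=32))
print(arguments)

# The collected dictionary is enough to rebuild the architecture.
# Extra entries such as 'training' are absorbed by **kwargs.
clone = ToySimpleNet(**arguments)
print(clone.kwargs)
```

In this sketch, the **kwargs in the constructor is what keeps re-instantiation from failing on entries that are not named constructor arguments.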
Loading and Saving
Loading and saving models is encapsulated in the following State class with static functions load and checkpoint. Besides saving model hyper-parameters and parameters, it can also store the optimizer and scheduler used for training so that training can easily be resumed:
Listing 3: The State class makes it easy to save and load models, potentially including an optimizer and learning rate scheduler. The key advantage is that it also saves the architecture's hyper-parameters, so the model can be loaded without instantiating the architecture first.
import os

import torch
from packaging import version  # assumption: version.parse is taken from the packaging package

# utils (get_class, makedir) is a helper module that is not shown in this article.


class State:
    """
    State of a model, including optional epoch and optimizer.
    """

    # 1. A state consists of a model, optimizer, scheduler and an epoch.
    # This allows to resume training given the used optimizer, its internal state, and the scheduler.
    def __init__(self, model, optimizer=None, scheduler=None, epoch=None):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.epoch = epoch

    # 2. Saves the current state:
    # - The model class name is saved.
    # - Parameters, optimizer, scheduler and epoch are added.
    # - All attributes of the model class that do not start with _ or __ are saved
    #   (these are assumed to be the model's hyper-parameters).
    # - All keyword arguments of the model are saved as well.
    def save(self, filepath):
        model = self.model
        if not isinstance(model, dict):
            model = model.state_dict()
        model_class = self.model.__class__.__name__

        optimizer = self.optimizer
        if not isinstance(optimizer, dict) and optimizer is not None:
            optimizer = optimizer.state_dict()

        scheduler = self.scheduler
        if not isinstance(scheduler, dict) and scheduler is not None:
            scheduler = scheduler.state_dict()

        epoch = self.epoch
        assert utils.get_class('models', model_class) is not False
        arguments = dict((key, getattr(self.model, key)) for key in dir(self.model)
                         if not callable(getattr(self.model, key))
                         and not key.startswith('_')
                         and not key == 'kwargs'
                         and not key == 'T_destination')
        kwargs = getattr(self.model, 'kwargs', None)
        utils.makedir(os.path.dirname(filepath))

        data = {'model': model, 'model_class': model_class, 'optimizer': optimizer, 'scheduler': scheduler,
                'epoch': epoch, 'arguments': arguments, 'kwargs': kwargs}
        if version.parse('1.6.0') < version.parse(torch.__version__):
            torch.save(data, filepath, pickle_protocol=2, _use_new_zipfile_serialization=False)
        else:
            torch.save(data, filepath)

    # 3. Loading a model:
    # - The architecture is instantiated using the saved class name and the saved arguments and keyword arguments.
    # - A new state is created using the model (after loading parameters), optimizer, scheduler and epoch.
    # Note that optimizer, scheduler and epoch may be None.
    @staticmethod
    def load(filepath):
        assert os.path.exists(filepath), 'file %s not found' % str(filepath)

        # https://discuss.pytorch.org/t/gpu-memory-usage-increases-by-90-after-torch-load/9213/3
        checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
        model_class = utils.get_class('models', checkpoint['model_class'])

        if 'kwargs' in checkpoint:
            arguments = {**checkpoint['arguments'], **checkpoint['kwargs']}
        else:
            arguments = {**checkpoint['arguments']}
        model = model_class(**arguments)
        model.load_state_dict(checkpoint['model'])

        state = State(model, checkpoint['optimizer'], checkpoint['scheduler'], checkpoint['epoch'])

        del checkpoint
        torch.cuda.empty_cache()

        return state

    @staticmethod
    def checkpoint(filepath, model, optimizer=None, scheduler=None, epoch=None):
        state = State(model, optimizer, scheduler, epoch)
        state.save(filepath)
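Listing 3 relies on two helpers, utils.get_class and utils.makedir, whose implementations are not shown in this article. The following is only a sketch of what they might look like: the function names and the convention of returning False for a missing class are taken from how Listing 3 uses them, while the bodies are my assumption.

```python
import importlib
import os


def get_class(module_name, class_name):
    """Look up class_name in module_name; return False if it cannot be found.

    Assumption: returning False matches the `is not False` check in Listing 3.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    return getattr(module, class_name, False)


def makedir(directory):
    """Create directory (including parents) if it does not exist yet."""
    # os.path.dirname('model.pth.tar') is the empty string; nothing to create then.
    if directory:
        os.makedirs(directory, exist_ok=True)


print(get_class('os', 'PathLike'))
print(get_class('os', 'NoSuchClass'))
```

Looking up the class by name means that only the class name, a plain string, has to be written to the file; the class itself must still be importable from the models module when loading.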
The important parts of the State class are the load and save functions (i.e., points 2 and 3):
- A state consists of a model and, optionally, an optimizer, a scheduler and an epoch. The latter three ensure that training can be resumed easily.
- The save function writes the state to file. This is done as follows: First, the model class is saved. This makes it easy to instantiate a new model when loading the file. Second, the model's arguments are stored; these are assumed to correspond to the model's attributes (not starting with _). This is why the hyper-parameters in Listing 2 are added as attributes. Third, the same is done for any potential keyword arguments. And fourth, the model's parameters as well as optimizer, scheduler and epoch are saved (the latter three may be None).
- Loading a state works as follows: After reading the model file, a new model is instantiated using the stored model class, its arguments and keyword arguments. Then, the corresponding parameters are loaded.
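The save and load steps above can be sketched end to end without PyTorch. This is a simplified illustration under stated assumptions, not the State class itself: pickle replaces torch.save and torch.load, a hypothetical REGISTRY dictionary replaces the class lookup by name, and ToyNet with its _weights list stands in for a model and its parameters.

```python
import os
import pickle
import tempfile


class ToyNet:
    """Hypothetical model: public attributes mirror the constructor arguments."""

    def __init__(self, N_class, resolution=(1, 32, 32), channels=64, **kwargs):
        self.N_class = int(N_class)
        self.resolution = list(resolution)
        self.channels = channels
        self.kwargs = kwargs
        self._weights = [0.0] * channels  # stands in for the learned parameters

    def state_dict(self):
        return {'weights': list(self._weights)}

    def load_state_dict(self, state):
        self._weights = list(state['weights'])


# Hypothetical registry replacing the lookup of the class by its name.
REGISTRY = {'ToyNet': ToyNet}


def save(model, filepath):
    # Public attributes are treated as the constructor arguments.
    arguments = {key: value for key, value in vars(model).items()
                 if not key.startswith('_') and key != 'kwargs'}
    data = {'model': model.state_dict(), 'model_class': type(model).__name__,
            'arguments': arguments, 'kwargs': model.kwargs}
    with open(filepath, 'wb') as handle:
        pickle.dump(data, handle)


def load(filepath):
    with open(filepath, 'rb') as handle:
        data = pickle.load(handle)
    # Re-create the architecture first, then restore its parameters.
    arguments = {**data['arguments'], **data['kwargs']}
    model = REGISTRY[data['model_class']](**arguments)
    model.load_state_dict(data['model'])
    return model


model = ToyNet(10, resolution=(3, 32, 32), channels=4, note='demo')
model._weights = [0.1, 0.2, 0.3, 0.4]
with tempfile.TemporaryDirectory() as directory:
    filepath = os.path.join(directory, 'toy.pkl')
    save(model, filepath)
    restored = load(filepath)

print(type(restored).__name__, restored.channels, restored.kwargs)
```

The calling code never mentions ToyNet when loading; the class name and arguments stored in the file are sufficient, which is the property the State class provides for real models.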
Put together, the State class can be used as follows:
Listing 4: Example of using State from Listing 3 to save and load a SimpleNet.
import models  # model definitions, e.g., SimpleNet from Listing 2
import common.state  # assumed to contain the State class from Listing 3

N_class = 10
resolution = [3, 32, 32]
dropout = False
normalization = 'bn'
model = models.SimpleNet(N_class, resolution, dropout=dropout, channels=64, normalization=normalization)

common.state.State.checkpoint('model.pth.tar', model)
print('wrote model.pth.tar')

# note that we do not need to initialize the model again!
state = common.state.State.load('model.pth.tar')
print('read model.pth.tar')

model = state.model
print(model)
Conclusion
In this article, I showed how a simple, custom State class can be used to load stored PyTorch models without knowing the architecture to be loaded in advance. This is achieved by having a common base classifier class and saving all arguments of the architecture (in addition to keyword arguments) as attributes of the model. Instead of saving only parameters, the State class in Listing 3 also saves the model class and its attributes. When loading the state file, the architecture can be instantiated from the class name, using the attributes as arguments.
Next in this article series, I will talk about computing adversarial examples, training adversarially robust models using adversarial or confidence-calibrated adversarial training [4], and evaluating robustness appropriately.
- [1] HasanPour, S. H., Rouhani, M., Fayyaz, M., and Sabokrou, M. Lets keep it simple, using simple architectures to outperform deeper and more complex architectures. arXiv.org, abs/1608.06037, 2016.
- [2] He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In CVPR, 2016.
- [3] Zagoruyko, S. and Komodakis, N. Wide residual networks. In BMVC, 2016.
- [4] Stutz, D., Hein, M., and Schiele, B. Confidence-calibrated adversarial training: Generalizing to unseen attacks. In ICML, 2020.