The sparsity algorithm zeros weights in a non-structured way, so that zero values are randomly distributed inside the tensor. Most pruning algorithms set the less important weights to zero, but they differ in the criteria used to select those weights. The framework contains several implementations of pruning methods.
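As an illustration of unstructured sparsity (a minimal NumPy sketch, not NNCF code), individual elements are zeroed at arbitrary positions with no row, column, or block pattern:

```python
import numpy as np

rng = np.random.default_rng(0)
weights = rng.normal(size=(4, 4))

# Zero roughly half of the elements at random positions
# (unstructured: the zeros follow no row/column pattern).
mask = rng.random(weights.shape) >= 0.5
sparse_weights = weights * mask

sparsity = 1.0 - np.count_nonzero(sparse_weights) / sparse_weights.size
print(f"sparsity: {sparsity:.2f}")
```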
The magnitude pruning method implements a naive approach based on the assumption that weights with smaller absolute values contribute less to the model's performance and can therefore be removed. After each training epoch, the method calculates a threshold based on the current pruning ratio and uses it to zero out weights that fall below this threshold.
There are two options:

- `UNSTRUCTURED_MAGNITUDE_LOCAL`: unstructured magnitude-based pruning with local importance calculation. Weight importance is computed independently for each tensor.
- `UNSTRUCTURED_MAGNITUDE_GLOBAL`: unstructured magnitude-based pruning with global importance calculation. Weight importance is computed across all tensors selected for pruning.
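The difference between the two modes can be illustrated with a small NumPy sketch (illustrative only, not the NNCF implementation): for the same overall ratio, the local variant thresholds each tensor independently, while the global variant computes one threshold across all tensors, so layers with smaller weight scales get pruned more aggressively.

```python
import numpy as np

def magnitude_masks_local(tensors, ratio):
    """Zero the `ratio` fraction of smallest-magnitude weights in each tensor separately."""
    masks = []
    for t in tensors:
        threshold = np.quantile(np.abs(t), ratio)
        masks.append(np.abs(t) > threshold)
    return masks

def magnitude_masks_global(tensors, ratio):
    """Zero the `ratio` fraction of smallest-magnitude weights across all tensors at once."""
    all_magnitudes = np.concatenate([np.abs(t).ravel() for t in tensors])
    threshold = np.quantile(all_magnitudes, ratio)
    return [np.abs(t) > threshold for t in tensors]

rng = np.random.default_rng(0)
# Two layers with different weight scales: the global threshold prunes
# the small-scale layer much harder than the local thresholds do.
layers = [rng.normal(scale=1.0, size=100), rng.normal(scale=0.1, size=100)]
local_masks = magnitude_masks_local(layers, ratio=0.5)
global_masks = magnitude_masks_global(layers, ratio=0.5)
print([m.mean() for m in local_masks])   # ~0.5 kept in each layer
print([m.mean() for m in global_masks])  # uneven: small-scale layer is pruned harder
```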
```python
import nncf

...

pruned_model = nncf.prune(
    model,
    mode=nncf.PruneMode.UNSTRUCTURED_MAGNITUDE_GLOBAL,
    ratio=0.7,
    examples_inputs=example_input,
)
```

To get a more accurate model, it is recommended to fine-tune the model for several epochs or to use batch norm adaptation.
When using magnitude pruning without fine-tuning, it is recommended to perform Batch Norm adaptation after pruning to get more accurate results.
```python
import nncf

...

def transform_fn(batch: tuple[torch.Tensor, int]) -> torch.Tensor:
    inputs, _ = batch
    return inputs.to(device=device)

calibration_dataset = nncf.Dataset(train_loader, transform_func=transform_fn)

pruned_model = nncf.batch_norm_adaptation(
    pruned_model,
    calibration_dataset=calibration_dataset,
    num_iterations=200,
)
```

The method is based on $\ell_0$-regularization, with which parameters of the model tend to zero:

$$\|\theta\|_0 = \sum_{i} \lbrack \theta_i \neq 0 \rbrack$$
We then reparametrize the network's weights as follows:

$$\theta_{sparse}^{(i)} = \theta_i \cdot \epsilon_i, \quad \epsilon_i \sim \mathcal{B}(p_i)$$

Here, $\epsilon_i$ can be interpreted as a binary mask that selects which weights are kept. To encourage the desired level of sparsity, a regularizing term is added to the objective function:

$$L_{sparse} = \mathbb{E}_{\epsilon \sim P_{\epsilon}} \left\lbrack \left( \frac{\sum_{i} \epsilon_i}{|\theta|} - level \right)^2 \right\rbrack$$

Here, $level$ is the target sparsity ratio. During training, we store and optimize the keep probabilities $p_i$ in the logit form:

$$s_i = \sigma^{-1}(p_i) = \log \frac{p_i}{1 - p_i}$$

and reparametrize the sampling of $\epsilon_i$ as follows:

$$\epsilon_i = \lbrack \sigma(s_i + \sigma^{-1}(\xi)) > 1/2 \rbrack, \quad \xi \sim \mathcal{U}(0, 1)$$

With this reparametrization, the probability of keeping a particular weight during the forward pass equals exactly $p_i$. To make the objective differentiable, the thresholding is treated as a straight-through estimator during backpropagation.
The method requires a long training schedule in order to minimize the accuracy drop.
```python
import nncf
from nncf.torch.function_hook.pruning.rb.losses import RBLoss

...

pruned_model = nncf.prune(
    model,
    mode=nncf.PruneMode.UNSTRUCTURED_REGULARIZATION_BASED,
    examples_inputs=example_input,
)

num_epochs = 30
rb_loss = RBLoss(pruned_model, target_ratio=0.7, p=0.1).to(device)

...

for epoch in range(num_epochs):
    for batch in train_loader:
        ...
        outputs = pruned_model(inputs)
        task_loss = criterion(outputs, targets)
        reg_loss = rb_loss()
        loss = task_loss + reg_loss
```

To gather statistics about the pruning process, use the following code:
```python
stat = nncf.pruning_statistic(pruned_model)
print(stat)
```

> [!NOTE]
> Statistics about the pruning process cannot be gathered after using `nncf.strip`.
The strip function is used to permanently apply the pruning masks to the model weights and to remove all auxiliary pruning-related operations.
After calling this function, the masks are merged into the weights, and any additional layers, parameters, or forward-pass operations introduced for pruning are removed. The resulting model contains only the pruned weights and can be used for inference without pruning overhead.
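Conceptually, stripping folds each mask into its weight tensor once, after which the plain forward pass produces the same result as the masked one (a NumPy sketch of the idea, not the NNCF implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
weight = rng.normal(size=(3, 3))
mask = rng.random((3, 3)) >= 0.5
x = rng.normal(size=3)

# Before stripping: the mask is applied on every forward pass.
masked_out = (weight * mask) @ x

# After stripping: the mask is merged into the weight once, then discarded.
weight *= mask
stripped_out = weight @ x

assert np.allclose(masked_out, stripped_out)
```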
```python
nncf.strip(pruned_model, strip_format=nncf.StripFormat.IN_PLACE)
```