This repository was archived by the owner on Mar 17, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 402
Expand file tree
/
Copy pathchannel_sparse_convolution.py
More file actions
executable file
·415 lines (362 loc) · 16.7 KB
/
channel_sparse_convolution.py
File metadata and controls
executable file
·415 lines (362 loc) · 16.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
import math
import numpy as np
import tensorflow as tf
from tensorflow.python.training import moving_averages
import niftynet.layer.bn
import niftynet.layer.convolution
import niftynet.layer.deconvolution
from niftynet.layer import layer_util
from niftynet.layer.activation import ActiLayer
from niftynet.layer.base_layer import TrainableLayer
from niftynet.layer.deconvolution import infer_output_dims
# Transposed-convolution op to use for each supported spatial rank;
# looked up by ChannelSparseDeconvLayer.layer_op.
SUPPORTED_OP = {'2D': tf.nn.conv2d_transpose,
'3D': tf.nn.conv3d_transpose}
class ChannelSparseDeconvLayer(niftynet.layer.deconvolution.DeconvLayer):
    """
    Channel sparse transposed convolution (deconvolution) layer.

    Channel sparse convolutions perform convolutions over
    a subset of image channels and generate a subset of output
    channels. This enables spatial dropout without wasted computations.
    """

    def __init__(self, *args, **kwargs):
        super(ChannelSparseDeconvLayer, self).__init__(*args, **kwargs)

    def layer_op(self, input_tensor, input_mask=None, output_mask=None):
        """
        Apply a transposed convolution restricted to the masked channels.

        The kernel variable ``w`` (and the bias ``b`` when ``with_bias`` is
        set) always has the dense size; only the slices selected by the
        masks take part in the computation.

        :param input_tensor: image to convolve with kernel
        :param input_mask: 1-Tensor with a binary mask of input channels to use
            If this is None, all channels are used.
        :param output_mask: 1-Tensor with a binary mask of output channels to
            generate. If this is None, all channels are used and the number
            of output channels is set at graph-creation time.
        :return: the output tensor of the transposed convolution
        :raises ValueError: if the input is neither 2D nor 3D spatially
        """
        input_shape = input_tensor.shape.as_list()
        spatial_rank = layer_util.infer_spatial_rank(input_tensor)
        if spatial_rank == 2:
            op_ = SUPPORTED_OP['2D']
            transpositions = [[3, 2, 1, 0], [1, 0, 2, 3], [3, 2, 0, 1]]
        elif spatial_rank == 3:
            op_ = SUPPORTED_OP['3D']
            transpositions = [[4, 3, 2, 1, 0], [1, 0, 2, 3, 4], [4, 3, 2, 0, 1]]
        else:
            raise ValueError(
                "Only 2D and 3D spatial deconvolutions are supported")

        # the mask length gives the dense channel count; without a mask the
        # input itself is dense
        if input_mask is None:
            n_full_input_chns = input_shape[-1]
        else:
            n_full_input_chns = input_mask.shape.as_list()[0]
        if output_mask is None:
            n_sparse_output_chns = self.n_output_chns
        else:
            # this count becomes part of `output_shape`, so it must be an
            # integer tensor rather than a float one
            n_sparse_output_chns = tf.reduce_sum(
                tf.cast(output_mask, tf.int32))

        # initialize conv kernels/strides and then apply;
        # kernel layout is [spatial..., output channels, input channels]
        w_full_size = layer_util.expand_spatial_params(
            self.kernel_size, spatial_rank)
        w_full_size = w_full_size + (self.n_output_chns, n_full_input_chns)
        full_stride = layer_util.expand_spatial_params(
            self.stride, spatial_rank)
        full_stride = (1,) + full_stride + (1,)
        deconv_kernel = tf.get_variable(
            'w', shape=list(w_full_size),
            initializer=self.initializers['w'],
            regularizer=self.regularizers['w'])

        # tf.boolean_mask filters along the leading axis, so each channel
        # axis is moved to the front, masked, and finally moved back
        sparse_kernel = deconv_kernel
        if input_mask is not None or output_mask is not None:
            # -> [in, out, reversed spatial...]
            sparse_kernel = tf.transpose(sparse_kernel, transpositions[0])
            if input_mask is not None:
                sparse_kernel = tf.boolean_mask(sparse_kernel, input_mask)
            # -> [out, in, reversed spatial...]
            sparse_kernel = tf.transpose(sparse_kernel, transpositions[1])
            if output_mask is not None:
                sparse_kernel = tf.boolean_mask(sparse_kernel, output_mask)
            # -> [spatial..., out, in]
            sparse_kernel = tf.transpose(sparse_kernel, transpositions[2])

        # NOTE(review): the helper is called with scalar arguments exactly
        # as before, and only the first spatial dimension is inspected --
        # confirm this against the signature of infer_output_dims.
        output_dim = infer_output_dims(input_shape[1],
                                       self.stride,
                                       self.kernel_size,
                                       self.padding)
        batch_size = input_shape[0]
        if batch_size is None:
            # unknown at graph-construction time: read it at run time
            batch_size = tf.shape(input_tensor)[0]
        sparse_output_size = \
            [batch_size] + [output_dim] * spatial_rank + [n_sparse_output_chns]
        sparse_output_size = tf.stack(sparse_output_size, 0)
        output_tensor = op_(value=input_tensor,
                            filter=sparse_kernel,
                            output_shape=sparse_output_size,
                            strides=list(full_stride),
                            padding=self.padding,
                            name='deconv')
        if output_mask is None:
            # If all output channels are used, we can specify
            # the number of output channels which is useful for later layers
            old_shape = output_tensor.shape.as_list()
            old_shape[-1] = self.n_output_chns
            output_tensor.set_shape(old_shape)
        if not self.with_bias:
            return output_tensor

        # adding the bias term
        bias_full_size = (self.n_output_chns,)
        bias_term = tf.get_variable(
            'b', shape=bias_full_size,
            initializer=self.initializers['b'],
            regularizer=self.regularizers['b'])
        if output_mask is None:
            sparse_bias = bias_term
        else:
            sparse_bias = tf.boolean_mask(bias_term, output_mask)
        output_tensor = tf.nn.bias_add(
            output_tensor, sparse_bias, name='add_bias')
        return output_tensor
class ChannelSparseConvLayer(niftynet.layer.convolution.ConvLayer):
    """
    Channel sparse convolutions perform convolutions over
    a subset of image channels and generate a subset of output
    channels. This enables spatial dropout without wasted computations.
    """

    def __init__(self, *args, **kwargs):
        super(ChannelSparseConvLayer, self).__init__(*args, **kwargs)

    def layer_op(self, input_tensor, input_mask=None, output_mask=None):
        """
        Convolve the input with the kernel slices selected by the masks.

        The kernel variable ``w`` (and the bias ``b`` when ``with_bias`` is
        set) always has the dense size; the masks pick the input/output
        channel slices that take part in the computation.

        :param input_tensor: image to convolve with kernel
        :param input_mask: 1-Tensor with a binary mask of input channels to use
            If this is None, all channels are used.
        :param output_mask: 1-Tensor with a binary mask of output channels to
            generate. If this is None, all channels are used and
            the number of output channels is set at graph-creation time.
        :return: the output tensor of the convolution
        :raises NotImplementedError: if the input is neither 2D nor 3D
            spatially
        """
        sparse_input_shape = input_tensor.shape.as_list()
        if input_mask is None:
            _input_mask = tf.ones([sparse_input_shape[-1]]) > 0
        else:
            _input_mask = input_mask
        if output_mask is None:
            _output_mask = tf.ones([self.n_output_chns]) > 0
        else:
            _output_mask = output_mask
        n_full_input_chns = _input_mask.shape.as_list()[0]
        spatial_rank = layer_util.infer_spatial_rank(input_tensor)

        # initialize conv kernels/strides and then apply
        w_full_size = layer_util.expand_spatial_params(
            self.kernel_size, spatial_rank)
        # expand kernel size to include number of features
        w_full_size = w_full_size + (n_full_input_chns, self.n_output_chns)
        full_stride = layer_util.expand_spatial_params(
            self.stride, spatial_rank)
        full_dilation = layer_util.expand_spatial_params(
            self.dilation, spatial_rank)
        conv_kernel = tf.get_variable(
            'w', shape=w_full_size,
            initializer=self.initializers['w'],
            regularizer=self.regularizers['w'])

        if spatial_rank == 2:
            transpositions = [[3, 2, 1, 0], [1, 0, 2, 3], [3, 2, 0, 1]]
        elif spatial_rank == 3:
            transpositions = [[4, 3, 2, 1, 0], [1, 0, 2, 3, 4], [4, 3, 2, 0, 1]]
        else:
            raise NotImplementedError("spatial rank not supported")

        # tf.boolean_mask filters along the leading axis, so each channel
        # axis is moved to the front, masked, and finally moved back:
        # [spatial..., in, out] -> [out, in, ...] -> [in, out, ...]
        # -> [spatial..., in, out]
        sparse_kernel = tf.transpose(conv_kernel, transpositions[0])
        sparse_kernel = tf.boolean_mask(sparse_kernel, _output_mask)
        sparse_kernel = tf.transpose(sparse_kernel, transpositions[1])
        sparse_kernel = tf.boolean_mask(sparse_kernel, _input_mask)
        sparse_kernel = tf.transpose(sparse_kernel, transpositions[2])

        output_tensor = tf.nn.convolution(input=input_tensor,
                                          filter=sparse_kernel,
                                          strides=full_stride,
                                          dilation_rate=full_dilation,
                                          padding=self.padding,
                                          name='conv')
        if output_mask is None:
            # If all output channels are used, we can specify
            # the number of output channels which is useful for later layers
            old_shape = output_tensor.shape.as_list()
            old_shape[-1] = self.n_output_chns
            output_tensor.set_shape(old_shape)
        if not self.with_bias:
            return output_tensor

        # adding the bias term
        bias_term = tf.get_variable(
            'b', shape=(self.n_output_chns,),
            initializer=self.initializers['b'],
            regularizer=self.regularizers['b'])
        # use the resolved mask: `output_mask` itself may be None
        sparse_bias = tf.boolean_mask(bias_term, _output_mask)
        output_tensor = tf.nn.bias_add(
            output_tensor, sparse_bias, name='add_bias')
        return output_tensor
class ChannelSparseBNLayer(niftynet.layer.bn.BNLayer):
    """
    Batch normalisation for the output of a channel sparse convolution.

    The trainable scale/offset and the moving statistics are stored for all
    ``n_dense_channels`` channels; a mask selects the entries that belong to
    the channels actually present in the input.
    """

    def __init__(self, n_dense_channels, *args, **kwargs):
        self.n_dense_channels = n_dense_channels
        super(ChannelSparseBNLayer, self).__init__(*args, **kwargs)

    def layer_op(self, inputs, is_training, mask, use_local_stats=False):
        """
        Normalise ``inputs`` with the parameters of the masked channels.

        :param inputs: image to normalize. This typically represents a sparse
            subset of channels from a sparse convolution.
        :param is_training: boolean that is True during training.
            When True, the layer uses batch statistics for normalization and
            records a moving average of means and variances.
            When False, the layer uses previously computed moving averages
            for normalization.
        :param mask: 1-Tensor with a binary mask identifying the sparse
            channels represented in inputs. If this is None, all
            ``n_dense_channels`` channels are selected.
        :param use_local_stats: when True, batch statistics are used for
            normalization even if ``is_training`` is False.
        :return: the normalised tensor, with the same static shape as
            ``inputs``
        """
        if mask is None:
            mask = tf.ones([self.n_dense_channels]) > 0

        # parameters are stored densely: one entry per mask element
        params_shape = mask.shape[-1:]
        assert params_shape[0] == self.n_dense_channels, \
            'Mask size {} must match n_dense_channels {}.'.format(
                params_shape[0], self.n_dense_channels)
        # statistics are taken over every axis except the channel axis
        axes = list(range(inputs.shape.ndims - 1))

        # create trainable variables and keep only the masked entries
        trainable_spec = dict(
            shape=params_shape, dtype=tf.float32, trainable=True)
        beta = tf.get_variable(
            'beta',
            initializer=self.initializers['beta'],
            regularizer=self.regularizers['beta'],
            **trainable_spec)
        gamma = tf.get_variable(
            'gamma',
            initializer=self.initializers['gamma'],
            regularizer=self.regularizers['gamma'],
            **trainable_spec)
        beta = tf.boolean_mask(beta, mask)
        gamma = tf.boolean_mask(gamma, mask)

        # create the (non-trainable) moving average variables
        stats_spec = dict(
            shape=params_shape, dtype=tf.float32, trainable=False,
            collections=[tf.GraphKeys.GLOBAL_VARIABLES])
        moving_mean = tf.get_variable(
            'moving_mean',
            initializer=self.initializers['moving_mean'],
            **stats_spec)
        moving_variance = tf.get_variable(
            'moving_variance',
            initializer=self.initializers['moving_variance'],
            **stats_spec)

        def _dense_update(batch_stat, stored_stat):
            # Dense vector holding the batch statistic at the masked
            # positions and the stored value everywhere else, so that the
            # moving average of unmasked channels is left where it was.
            masked_idx = tf.to_int32(tf.where(mask)[:, 0])
            unmasked_idx = tf.to_int32(tf.where(~mask)[:, 0])
            untouched = tf.boolean_mask(stored_stat, ~mask)
            return tf.dynamic_stitch(
                [masked_idx, unmasked_idx], [batch_stat, untouched])

        # mean and var of the current batch
        mean, variance = tf.nn.moments(inputs, axes)
        # only update masked moving averages
        mean_update = _dense_update(mean, moving_mean)
        variance_update = _dense_update(variance, moving_variance)
        moving_average_ops = [
            moving_averages.assign_moving_average(
                moving_mean, mean_update, self.moving_decay).op,
            moving_averages.assign_moving_average(
                moving_variance, variance_update, self.moving_decay).op]
        for update_op in moving_average_ops:
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op)

        # pick the statistics, then call the normalisation function
        if is_training or use_local_stats:
            norm_mean, norm_variance = mean, variance
        else:
            norm_mean = tf.boolean_mask(moving_mean, mask)
            norm_variance = tf.boolean_mask(moving_variance, mask)
        outputs = tf.nn.batch_normalization(
            inputs, norm_mean, norm_variance,
            beta, gamma, self.eps, name='batch_norm')
        outputs.set_shape(inputs.get_shape())
        return outputs
class ChannelSparseConvolutionalLayer(TrainableLayer):
    """
    This class defines a composite layer with optional components::

        channel sparse convolution ->
        batchwise-spatial dropout ->
        batch_norm ->
        activation

    The b_initializer and b_regularizer are applied to
    the ChannelSparseConvLayer, the w_initializer and w_regularizer
    are applied to the ChannelSparseConvLayer, the batch normalisation
    layer, and the activation layer (for 'prelu')
    """

    def __init__(self,
                 n_output_chns,
                 kernel_size=3,
                 stride=1,
                 dilation=1,
                 padding='SAME',
                 with_bias=False,
                 feature_normalization='batch',
                 acti_func=None,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 moving_decay=0.9,
                 eps=1e-5,
                 name="conv"):
        """
        Store the configuration that ``layer_op`` forwards to the
        convolution, batch normalisation and activation sub-layers.

        The layer name is ``name`` with ``_bn`` appended when
        ``feature_normalization`` is ``'batch'`` and ``_<acti_func>``
        appended when an activation function is given. Missing
        initializers fall back to the defaults of
        ``niftynet.layer.convolution``.
        """
        self.acti_func = acti_func
        self.feature_normalization = feature_normalization
        self.layer_name = '{}'.format(name)
        if self.feature_normalization == 'batch':
            self.layer_name += '_bn'
        if self.acti_func is not None:
            self.layer_name += '_{}'.format(self.acti_func)
        super(ChannelSparseConvolutionalLayer, self).__init__(
            name=self.layer_name)

        # for ConvLayer
        self.n_output_chns = n_output_chns
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.with_bias = with_bias

        # for BNLayer
        self.moving_decay = moving_decay
        self.eps = eps

        self.initializers = {
            'w': w_initializer if w_initializer else
                 niftynet.layer.convolution.default_w_initializer(),
            'b': b_initializer if b_initializer else
                 niftynet.layer.convolution.default_b_initializer()}
        self.regularizers = {'w': w_regularizer, 'b': b_regularizer}

    def layer_op(self,
                 input_tensor,
                 input_mask=None,
                 is_training=None,
                 keep_prob=None):
        """
        Run the composite layer on ``input_tensor``.

        :param input_tensor: image to convolve
        :param input_mask: 1-Tensor with a binary mask of the input channels
            represented in ``input_tensor``; None means all channels.
        :param is_training: passed to the batch normalisation layer; must be
            set when ``feature_normalization`` is ``'batch'``.
        :param keep_prob: fraction of output channels to generate. A random
            subset of ``ceil(keep_prob * n_output_chns)`` channels is
            selected; None keeps every channel.
        :return: tuple ``(output_tensor, output_mask)`` where ``output_mask``
            marks the output channels present in ``output_tensor``
        :raises ValueError: if batch normalisation is requested and
            ``is_training`` is None
        """
        conv_layer = ChannelSparseConvLayer(
            n_output_chns=self.n_output_chns,
            kernel_size=self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=self.padding,
            with_bias=self.with_bias,
            w_initializer=self.initializers['w'],
            w_regularizer=self.regularizers['w'],
            b_initializer=self.initializers['b'],
            b_regularizer=self.regularizers['b'],
            name='conv_')

        if keep_prob is not None:
            # Derive the mask from the integer channel count. Comparing the
            # shuffled indices against the float product instead can select
            # a different number of channels than the static size recorded
            # below (e.g. 0.7 * 10 is 7.000000000000001 in double precision
            # but 7.0 once cast to float32).
            n_output_ch = int(math.ceil(keep_prob * self.n_output_chns))
            n_output_ch = min(n_output_ch, self.n_output_chns)
            output_mask = \
                tf.random_shuffle(tf.range(self.n_output_chns)) < n_output_ch
        else:
            output_mask = tf.ones([self.n_output_chns]) > 0
            n_output_ch = self.n_output_chns

        output_tensor = conv_layer(input_tensor, input_mask, output_mask)
        # masking hides the channel count from shape inference; restore it
        output_tensor.set_shape(
            output_tensor.shape.as_list()[:-1] + [n_output_ch])

        if self.feature_normalization == 'batch':
            if is_training is None:
                raise ValueError('For batch norm, you must set the `is_training` argument.')
            bn_layer = ChannelSparseBNLayer(
                self.n_output_chns,
                regularizer=self.regularizers['w'],
                moving_decay=self.moving_decay,
                eps=self.eps,
                name='bn_')
            output_tensor = bn_layer(output_tensor, is_training, output_mask)

        if self.acti_func is not None:
            acti_layer = ActiLayer(
                func=self.acti_func,
                regularizer=self.regularizers['w'],
                name='acti_')
            output_tensor = acti_layer(output_tensor)
        return output_tensor, output_mask