I defined a weight with a custom, reduced shape and then tile it with `repeat` to the full kernel shape right before the convolution.
import torch
from typing import Optional, List, Tuple, Union
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.nn.common_types import _size_2_t
from torch.nn.modules.utils import _single, _pair, _triple, _reverse_repeat_tuple
from torch.nn import init
import math
from torch.nn import functional as F
class _ConvNd(torch.nn.Module):
__constants__ = ['stride', 'padding', 'dilation', 'groups',
'padding_mode', 'output_padding', 'in_channels',
'out_channels', 'kernel_size']
__annotations__ = {'bias': Optional[torch.Tensor]}
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
...
_in_channels: int
_reversed_padding_repeated_twice: List[int]
out_channels: int
kernel_size: Tuple[int, ...]
stride: Tuple[int, ...]
padding: Union[str, Tuple[int, ...]]
dilation: Tuple[int, ...]
transposed: bool
output_padding: Tuple[int, ...]
groups: int
padding_mode: str
weight: Tensor
bias: Optional[Tensor]
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, ...],
stride: Tuple[int, ...],
padding: Tuple[int, ...],
dilation: Tuple[int, ...],
transposed: bool,
output_padding: Tuple[int, ...],
groups: int,
bias: bool,
padding_mode: str,
device=None,
dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(_ConvNd, self).__init__()
if in_channels % groups != 0:
raise ValueError('in_channels must be divisible by groups')
if out_channels % groups != 0:
raise ValueError('out_channels must be divisible by groups')
valid_padding_strings = {'same', 'valid'}
if isinstance(padding, str):
if padding not in valid_padding_strings:
raise ValueError(
"Invalid padding string {!r}, should be one of {}".format(
padding, valid_padding_strings))
if padding == 'same' and any(s != 1 for s in stride):
raise ValueError("padding='same' is not supported for strided convolutions")
valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
if padding_mode not in valid_padding_modes:
raise ValueError("padding_mode must be one of {}, but got padding_mode='{}'".format(
valid_padding_modes, padding_mode))
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.transposed = transposed
self.output_padding = output_padding
self.groups = groups
self.padding_mode = padding_mode
# `_reversed_padding_repeated_twice` is the padding to be passed to
# `F.pad` if needed (e.g., for non-zero padding types that are
# implemented as two ops: padding + conv). `F.pad` accepts paddings in
# reverse order than the dimension.
if isinstance(self.padding, str):
self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size)
if padding == 'same':
for d, k, i in zip(dilation, kernel_size,
range(len(kernel_size) - 1, -1, -1)):
total_padding = d * (k - 1)
left_pad = total_padding // 2
self._reversed_padding_repeated_twice[2 * i] = left_pad
self._reversed_padding_repeated_twice[2 * i + 1] = (
total_padding - left_pad)
else:
self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2)
if transposed:
self.weight = Parameter(torch.empty(
(in_channels, out_channels // groups, *kernel_size), **factory_kwargs))
else:
self.weight = Parameter(torch.empty(
(groups, 1, *kernel_size), **factory_kwargs))
if bias:
self.bias = Parameter(torch.empty(out_channels, **factory_kwargs))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self) -> None:
# Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
# uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size)
# For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
if fan_in != 0:
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def extra_repr(self):
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
', stride={stride}')
if self.padding != (0,) * len(self.padding):
s += ', padding={padding}'
if self.dilation != (1,) * len(self.dilation):
s += ', dilation={dilation}'
if self.output_padding != (0,) * len(self.output_padding):
s += ', output_padding={output_padding}'
if self.groups != 1:
s += ', groups={groups}'
if self.bias is None:
s += ', bias=False'
if self.padding_mode != 'zeros':
s += ', padding_mode={padding_mode}'
return s.format(**self.__dict__)
def __setstate__(self, state):
super(_ConvNd, self).__setstate__(state)
if not hasattr(self, 'padding_mode'):
self.padding_mode = 'zeros'
class SharedConv2d(_ConvNd):
    """2-d convolution whose kernels are tiled from a small shared weight.

    The stored weight is tiled with ``Tensor.repeat`` to
    ``(out_channels, in_channels // groups, kH, kW)`` on every forward pass
    and then handed to ``F.conv2d``. ``repeat`` is a differentiable op, so
    the gradients of all tiled copies are accumulated into the shared weight.

    NOTE(review): ``repeat`` tiles the whole tensor, so along dim 0 the
    kernels cycle through the stored ones, whereas ``F.conv2d`` assigns
    output channels to groups in contiguous chunks. Confirm this is the
    intended sharing pattern.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None
    ) -> None:
        """Normalise scalar arguments to pairs and delegate to the base class."""
        # String paddings ('same' / 'valid') are forwarded untouched.
        if isinstance(padding, str):
            padding_arg = padding
        else:
            padding_arg = _pair(padding)
        super().__init__(
            in_channels,
            out_channels,
            _pair(kernel_size),
            _pair(stride),
            padding_arg,
            _pair(dilation),
            False,       # transposed
            _pair(0),    # output_padding
            groups,
            bias,
            padding_mode,
            device=device,
            dtype=dtype,
        )

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
        """Tile ``weight`` to the full kernel shape and run the convolution."""
        if self.padding_mode == 'zeros':
            # Zero padding is handled by conv2d itself.
            conv_input = input
            conv_padding = self.padding
        else:
            # Other modes are applied explicitly, so conv2d must not pad again.
            conv_input = F.pad(input, self._reversed_padding_repeated_twice,
                               mode=self.padding_mode)
            conv_padding = _pair(0)

        out_per_group = self.out_channels // self.groups
        in_per_group = self.in_channels // self.groups
        tiled_weight = weight.repeat(out_per_group, in_per_group, 1, 1)
        return F.conv2d(conv_input, tiled_weight, bias, self.stride,
                        conv_padding, self.dilation, self.groups)

    def forward(self, input: Tensor) -> Tensor:
        """Apply the convolution using the module's own weight and bias."""
        return self._conv_forward(input, self.weight, self.bias)
Then we can use it like this:
SharedConv2d(kernel_size=(1,20), stride=1, groups=5, out_channels=30, in_channels=30, bias=False)
The question is: will gradients backpropagate correctly to the shared weight
when `repeat` is applied to it in the forward pass?