I would like to implement a custom loss function, and I am using tensorflow with keras backend.
In my loss function for each training sample (2D matrix of size (2048x192) I would like to add a bandpassed version of the corresponding training sample as a constant (non-trainable) value.
I implemented bandpass filter based on How to implement band-pass Butterworth filter with Scipy.signal.butter
from scipy.signal import butter, sosfiltfilt
def butter_bandpass_sos(lowcut, highcut, fs, order=5):
nyq = 0.5 * fs
low = lowcut / nyq
high = highcut / nyq
sos = butter(order, [low, high], analog=False, btype='band', output='sos')
return sos
def butter_bandpass_filter_sos(data, lowcut, highcut, fs, order=5):
sos = butter_bandpass_sos(lowcut, highcut, fs, order=order)
y = sosfiltfilt(sos, data)
return y
and for the loss function based on Adding a constant to Loss function in Tensorflow, I implemented:
from tensorflow.python.keras import backend
import tensorflow as tf
lowcut = 2.5e6
highcut = 7.5e6
order = 5
fs = 40e6
def HP_func(mat):
for i in range(0, 192):
RF_ch = mat[:, i]
y = butter_bandpass_filter_sos(RF_ch, lowcut, highcut, fs, order=order)
mat_band_sos[:, i] = y
return mat_band_sos
def my_custom_loss_HF(y_true,y_pred):
HF_mat = HP_func(y_true)
loss = backend.sqrt(tf.keras.losses.mean_squared_error(y_true, y_pred)) + HF_mat
return loss
I have three branches and therefore three losses:
model.compile(loss=['mean_squared_error', my_custom_loss_HF,'mean_squared_error'],
loss_weights=[1.0, 1.0, 1.0],
optimizer='Adam',
metrics=['mae', rmse])
but I am getting this error:
Traceback (most recent call last):
File "/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-19-01efc0717223>", line 6, in <module>
metrics=['mae', rmse])
File "/home/z003zpjj/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 457, in _method_wrapper
result = method(self, *args, **kwargs)
File "/home/z003zpjj/venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 337, in compile
self._compile_weights_loss_and_weighted_metrics()
File "/home/z003zpjj/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 457, in _method_wrapper
result = method(self, *args, **kwargs)
File "/home/z003zpjj/venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1710, in _compile_weights_loss_and_weighted_metrics
self.total_loss = self._prepare_total_loss(masks)
File "/home/z003zpjj/venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1770, in _prepare_total_loss
per_sample_losses = loss_fn.call(y_true, y_pred)
File "/home/z003zpjj/venv/lib/python3.6/site-packages/tensorflow/python/keras/losses.py", line 215, in call
return self.fn(y_true, y_pred, **self._fn_kwargs)
File "<ipython-input-18-a4f1cf924d3f>", line 3, in my_custom_loss_HF
HF_mat = HP_fun(y_true)
File "<ipython-input-17-74a2f0e736b9>", line 19, in HP_fun
y = butter_bandpass_filter_sos(RF_ch, lowcut, highcut, fs, order=order)
File "<ipython-input-2-4e34aa35b4cd>", line 69, in butter_bandpass_filter_sos
y = sosfiltfilt(sos, data)
File "/home/z003zpjj/venv/lib/python3.6/site-packages/scipy/signal/signaltools.py", line 4131, in sosfiltfilt
x = _validate_x(x)
File "/home/z003zpjj/venv/lib/python3.6/site-packages/scipy/signal/signaltools.py", line 3926, in _validate_x
raise ValueError('x must be at least 1D')
ValueError: x must be at least 1D
How can I use the scipy function in my loss function?