I created a custom pooling layer using tensorflow layer subclassing. Here is the code:
class Min_Max_Pooling(tf.keras.layers.Layer):
def __init__(self, filter_size):
super(Min_Max_Pooling, self).__init__()
self.filter_size = filter_size
def call(self, inputs):
print('------------------------------------------------------')
print(f'inputs shape = {inputs.shape[-1]}')
print(f'filter_size = {self.filter_size}')
result = tf.zeros((int((inputs.shape[-1]/self.filter_size)*2)))
num_splits_length = int(inputs.shape[-1]/self.filter_size)
print(f'num_splits_length = {num_splits_length}')
print(f'result_length = {result.shape[-1]}')
split_sequence = tf.split(inputs, num_or_size_splits=num_splits_length, axis=-1)
count = 0
index = 0
def cond(count,result,index):
return tf.less(count,len(split_sequence))
def body(count,result,index):
# print(tf.gather(split_sequence, count))
max = tf.reduce_max(tf.gather(split_sequence, count))
min = tf.reduce_min(tf.gather(split_sequence, count))
index_max = tf.argmax(tf.gather(split_sequence, count))
index_min = tf.argmin(tf.gather(split_sequence, count))
indices = [[index], [index+1]]
if tf.cond(tf.greater(index_max , index_min), lambda: tf.constant(True), lambda: tf.constant(False)):
updates = [min,max]
else:
updates = [max,min]
result = tf.tensor_scatter_nd_update(result, indices, updates)
index +=2
count+=1
# print(f'result_shape = {result.shape[-1]}')
return [count,result,index]
output = tf.while_loop(cond, body, [count,result,index])[1]
return output
The motive of this layer is to downsample a timeseries something like:
The layer accepts a filter size and a tensor(timeseries) then it splits the tensor in to chunks according to filter size and then it loops over them calculating the min and max values of chunk and updates it to 'result' tensor using 'tf.tensor_scatter_nd_update' method according to there index (if max value comes before the min value then it first appends the max value and then the min value so as to not disturb the sequence of time series).
I the created the model using keras functional api here is the code:
input_layer = tf.keras.layers.Input(shape=(1000), name="input_layer")
layer_1 = Min_Max_Pooling(filter_size=4)(input_layer)
model = tf.keras.models.Model(input_layer, layer_1, name="model")
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005), loss="categorical_crossentropy", run_eagerly=True)
print(model.summary())
After compiling the model i did model.predict here are the results:
data = pd.read_csv('/content/drive/MyDrive/stock.csv', parse_dates=False,
index_col=1)
tensor = data.close.head(1000).to_numpy()
tensor = tensor / max(tensor)
tensor = tf.convert_to_tensor(tensor)
# print(tensor)
# print(model.summary())
# tensor = tf.reshape(tensor, (1000))
result = model.predict(tensor)
output :
------------------------------------------------------
inputs shape = 1000
filter_size = 4
num_splits_length = 250
result_length = 500
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_layer (InputLayer) [(None, 1000)] 0
min__max__pooling_16 (Min_M (500,) 0
ax_Pooling)
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
None
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 32
filter_size = 4
num_splits_length = 8
result_length = 16
------------------------------------------------------
inputs shape = 8
filter_size = 4
num_splits_length = 2
result_length = 4
my problem is why is input shape changing from 1000 to 32 and why is the function called multiple times with parts of input of shape 32 and also if you change the filter size to 5 it results in to error as 32 is not completely divisible by 5. I don't understand what is the problem here. Dose anyone know how to solve this problem.
Here is the full code to reproduce the error :
import tensorflow as tf
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
plt.rcParams['figure.figsize'] = [25, 5]
class Min_Max_Pooling(tf.keras.layers.Layer):
def __init__(self, filter_size):
super(Min_Max_Pooling, self).__init__()
self.filter_size = filter_size
def call(self, inputs):
print('------------------------------------------------------')
# print(f'inputs = {inputs}')
print(f'inputs shape = {inputs.shape[-1]}')
print(f'filter_size = {self.filter_size}')
# print(f'remainder = {int(inputs.shape[-1]%self.filter_size)}')
result = tf.zeros((int((inputs.shape[-1]/self.filter_size)*2)))
num_splits_length = int(inputs.shape[-1]/self.filter_size)
print(f'num_splits_length = {num_splits_length}')
print(f'result_length = {result.shape[-1]}')
split_sequence = tf.split(inputs, num_or_size_splits=num_splits_length, axis=-1)
count = 0
index = 0
def cond(count,result,index):
return tf.less(count,len(split_sequence))
def body(count,result,index):
# print(tf.gather(split_sequence, count))
max = tf.reduce_max(tf.gather(split_sequence, count))
min = tf.reduce_min(tf.gather(split_sequence, count))
index_max = tf.argmax(tf.gather(split_sequence, count))
index_min = tf.argmin(tf.gather(split_sequence, count))
indices = [[index], [index+1]]
if tf.cond(tf.greater(index_max , index_min), lambda: tf.constant(True), lambda: tf.constant(False)):
updates = [min,max]
else:
updates = [max,min]
result = tf.tensor_scatter_nd_update(result, indices, updates)
index +=2
count+=1
# print(f'result_shape = {result.shape[-1]}')
return [count,result,index]
output = tf.while_loop(cond, body, [count,result,index])[1]
return output
input_layer = tf.keras.layers.Input(shape=(1000), name="input_layer")
lambda_layer_1 = Min_Max_Pooling(filter_size=4)(input_layer)
model = tf.keras.models.Model(input_layer, lambda_layer_1, name="model")
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005), loss="categorical_crossentropy", run_eagerly=True)
print(model.summary())
data = pd.read_csv('/content/drive/MyDrive/ADANIPORTS.csv', parse_dates=False,
index_col=1)
tensor = data.close.head(1000).to_numpy()
tensor = tensor / max(tensor)
tensor = tf.convert_to_tensor(tensor)
print(tensor)
print(model.summary())
tensor = tf.reshape(tensor, (1000))
result = model.predict(tensor)
plt.plot(tensor)
plt.show()
plt.plot(result)
plt.show()
as i can't upload the stock file i will put the code to generate the tensor here or you can use any timeseries data.
a = [0.9921363 , 0.97597204, 0.9708752 , 0.9781564 , 0.97917577,
0.98733071, 0.98689384, 0.98383574, 0.98470948, 0.98034076,
0.97771953, 0.97655454, 0.97742828, 0.97728266, 0.97670016,
0.97830202, 0.98441823, 0.98339886, 0.99009757, 0.9927188 ,
0.98893258, 0.9860201 , 0.98733071, 0.98645697, 0.98514635,
0.98500073, 0.98529198, 0.98616572, 0.98529198, 0.98354449,
0.98354449, 0.97975826, 0.97801078, 0.97946702, 0.97946702,
0.98034076, 0.97990389, 0.97961264, 0.97786515, 0.97830202,
0.97859327, 0.9781564 , 0.97903014, 0.98150575, 0.98208825,
0.98354449, 0.98150575, 0.98019514, 0.97859327, 0.97801078,
0.98004951, 0.98136013, 0.981797 , 0.98019514, 0.98063201,
0.98077763, 0.97946702, 0.98150575, 0.97961264, 0.97975826,
0.97975826, 0.98063201, 0.97932139, 0.97801078, 0.97655454,
0.97626329, 0.9750983 , 0.97713703, 0.981797 , 0.98106888,
0.98034076, 0.97830202, 0.97844765, 0.97786515, 0.97247706,
0.97291394, 0.96971021, 0.96941896, 0.97000146, 0.97364206,
0.96869084, 0.96956458, 0.96854522, 0.96490462, 0.96577836,
0.96708898, 0.9672346 , 0.96403087, 0.96344838, 0.96315713,
0.963594 , 0.96373962, 0.96432212, 0.96388525, 0.96257463,
0.96140964, 0.96490462, 0.96330275, 0.96621523, 0.96592398,
0.96403087, 0.96184651, 0.96111839, 0.96082714, 0.96199214,
0.96242901, 0.96257463, 0.96272026, 0.96432212, 0.96432212,
0.96330275, 0.96315713, 0.96228338, 0.96082714, 0.96111839,
0.96140964, 0.96068152, 0.95733217, 0.95893403, 0.95485656,
0.95602155, 0.95514781, 0.95602155, 0.95514781, 0.95645842,
0.95747779, 0.95602155, 0.95602155, 0.9568953 , 0.95791466,
0.96111839, 0.95820591, 0.95776904, 0.95849716, 0.95806029,
0.9593709 , 0.95878841, 0.96432212, 0.96140964, 0.9593709 ,
0.95922528, 0.95864278, 0.95922528, 0.95878841, 0.95966215,
0.95776904, 0.95485656, 0.95136158, 0.9532547 , 0.95529343,
0.95034222, 0.94946847, 0.9532547 , 0.95529343, 0.96606961,
0.96126402, 0.96257463, 0.96039027, 0.95820591, 0.96009902,
0.95791466, 0.95340032, 0.95063346, 0.9484491 , 0.95310907,
0.95296345, 0.95427406, 0.94990534, 0.94859473, 0.95121596,
0.95063346, 0.95558468, 0.95645842, 0.96039027, 0.96636086,
0.96461337, 0.96184651, 0.96563274, 0.96286588, 0.96170089,
0.96330275, 0.9599534 , 0.96068152, 0.95820591, 0.95485656,
0.95718654, 0.95602155, 0.96009902, 0.9593709 , 0.95747779,
0.9557303 , 0.95762342, 0.95558468, 0.95529343, 0.95558468,
0.95616718, 0.95529343, 0.95543906, 0.95704092, 0.95340032,
0.95296345, 0.95310907, 0.94874035, 0.94946847, 0.94874035,
0.94859473, 0.94917722, 0.95136158, 0.94801223, 0.94786661,
0.94801223, 0.94830348, 0.94568225, 0.94349789, 0.94262414,
0.94568225, 0.94364351, 0.94408038, 0.93956604, 0.94102228,
0.93781855, 0.93621669, 0.93432358, 0.9314111 , 0.93607106,
0.92223678, 0.92325615, 0.92121742, 0.9229649 , 0.92325615,
0.92806174, 0.92908111, 0.92849862, 0.92951798, 0.93155672,
0.93126547, 0.93199359, 0.93155672, 0.93199359, 0.93010048,
0.92995486, 0.9308286 , 0.93111985, 0.9308286 , 0.93053735,
0.93039173, 0.93199359, 0.93257609, 0.93126547, 0.92980923,
0.93039173, 0.92995486, 0.9277705 , 0.92238241, 0.92442114,
0.92267366, 0.9229649 , 0.92092617, 0.92529489, 0.92616863,
0.92645988, 0.92471239, 0.92500364, 0.92806174, 0.92995486,
0.92966361, 0.93010048, 0.92908111, 0.92922674, 0.92864424,
0.92908111, 0.93053735, 0.93330421, 0.93199359, 0.93286734,
0.93199359, 0.93199359, 0.93199359, 0.93053735, 0.9314111 ,
0.9308286 , 0.93126547, 0.93243046, 0.93315858, 0.92995486,
0.92878986, 0.92893549, 0.92878986, 0.92849862, 0.92427552,
0.92194554, 0.92005242, 0.92223678, 0.9199068 , 0.91524683,
0.91116936, 0.90563565, 0.90549002, 0.9011213 , 0.90286879,
0.90461628, 0.90942187, 0.90913062, 0.90330566, 0.90505315,
0.90723751, 0.90403378, 0.90214067, 0.90286879, 0.90286879,
0.90330566, 0.90214067, 0.89325761, 0.89121887, 0.88947138,
0.89180137, 0.88845202, 0.89267511, 0.89471385, 0.89587884,
0.89631571, 0.89981069, 0.89718946, 0.89296636, 0.89573322,
0.89908257, 0.89631571, 0.89515072, 0.89573322, 0.90214067,
0.90184942, 0.90316004, 0.89995631, 0.89995631, 0.89893694,
0.89981069, 0.9017038 , 0.90155818, 0.90126693, 0.89835445,
0.90097568, 0.89908257, 0.90199505, 0.90126693, 0.9017038 ,
0.9011213 , 0.89908257, 0.89602446, 0.89617009, 0.89558759,
0.8974807 , 0.89718946, 0.8986457 , 0.89820882, 0.90010194,
0.89995631, 0.8974807 , 0.89587884, 0.88976263, 0.89107325,
0.88277268, 0.88015145, 0.88058832, 0.88102519, 0.88000582,
0.8798602 , 0.88335518, 0.87956895, 0.88248143, 0.87811271,
0.87840396, 0.8756371 , 0.87330712, 0.87083151, 0.87112276,
0.85277414, 0.85772535, 0.85539537, 0.85976409, 0.86486093,
0.86544343, 0.87039464, 0.87199651, 0.87636522, 0.87461774,
0.87490899, 0.87549148, 0.87811271, 0.88204456, 0.88495704,
0.8865589 , 0.89354886, 0.89092762, 0.89617009, 0.89311198,
0.8913645 , 0.88903451, 0.88481142, 0.88510266, 0.88728702,
0.88699578, 0.890782 , 0.88495704, 0.88510266, 0.88481142,
0.88495704, 0.88248143, 0.88248143, 0.88102519, 0.88131644,
0.8840833 , 0.88248143, 0.88102519, 0.88175331, 0.88102519,
0.88175331, 0.88175331, 0.88524829, 0.88160769, 0.88160769,
0.8798602 , 0.88087957, 0.88306393, 0.88189894, 0.88248143,
0.88131644, 0.88058832, 0.8798602 , 0.88000582, 0.87913208,
0.87956895, 0.88146206, 0.88058832, 0.88248143, 0.87956895,
0.87956895, 0.87607398, 0.87432649, 0.87403524, 0.87869521,
0.87738459, 0.87869521, 0.87782146, 0.87418087, 0.87476336,
0.87592835, 0.87520023, 0.87461774, 0.87592835, 0.87374399,
0.89252949, 0.88903451, 0.88510266, 0.8871414 , 0.89413135,
0.89689821, 0.89922819, 0.89704383, 0.89762633, 0.89267511,
0.89471385, 0.89660696, 0.8974807 , 0.89791758, 0.90243192,
0.90417941, 0.90403378, 0.90403378, 0.90403378, 0.90345129,
0.90214067, 0.90257754, 0.90316004, 0.90374254, 0.90505315,
0.90330566, 0.90505315, 0.9059269 , 0.90738314, 0.90811126,
0.90869375, 0.91218873, 0.91510121, 0.91422746, 0.91146061,
0.90650939, 0.90854813, 0.90636377, 0.91422746, 0.91524683,
0.9199068 , 0.91830494, 0.92194554, 0.92223678, 0.92311053,
0.9193243 , 0.92107179, 0.91976118, 0.91917868, 0.91917868,
0.91743119, 0.91480996, 0.9156837 , 0.9168487 , 0.91917868,
0.9193243 , 0.91393622, 0.90782001, 0.90257754, 0.90199505,
0.90650939, 0.90403378, 0.9095675 , 0.91087811, 0.91437309,
0.91408184, 0.91815931, 0.91917868, 0.92150866, 0.92398427,
0.92573176, 0.92398427, 0.92704238, 0.92864424, 0.93053735,
0.93519732, 0.93709043, 0.93796418, 0.93781855, 0.93665356,
0.93577982, 0.93374108, 0.93359546, 0.9344692 , 0.93301296,
0.93330421, 0.93272171, 0.93476045, 0.93359546, 0.93228484,
0.93344983, 0.93490607, 0.93476045, 0.93461482, 0.93417795,
0.93927479, 0.945391 , 0.94728411, 0.95005097, 0.9526722 ,
0.94801223, 0.94553662, 0.95077909, 0.94946847, 0.95034222,
0.95092471, 0.95034222, 0.94786661, 0.95005097, 0.94975972,
0.94990534, 0.94684724, 0.94859473, 0.95310907, 0.95645842,
0.95441969, 0.95238095, 0.95238095, 0.95558468, 0.95558468,
0.9532547 , 0.95441969, 0.95398282, 0.95427406, 0.95383719,
0.95150721, 0.9484491 , 0.94975972, 0.95019659, 0.95281782,
0.95194408, 0.95150721, 0.95063346, 0.95281782, 0.95471094,
0.95369157, 0.95529343, 0.95776904, 0.95398282, 0.95383719,
0.95471094, 0.95310907, 0.95369157, 0.95310907, 0.95136158,
0.94946847, 0.94874035, 0.94699286, 0.95238095, 0.95165283,
0.96082714, 0.96621523, 0.97014708, 0.96810834, 0.9708752 ,
0.96839959, 0.96796272, 0.96767147, 0.96985583, 0.97233144,
0.96956458, 0.97058395, 0.96796272, 0.97218582, 0.97058395,
0.97174894, 0.97305956, 0.97305956, 0.97582642, 0.97742828,
0.97670016, 0.97830202, 0.9787389 , 0.97830202, 0.9787389 ,
0.97611766, 0.97699141, 0.97407893, 0.97597204, 0.97582642,
0.97131207, 0.97568079, 0.97713703, 0.97568079, 0.97407893,
0.97626329, 0.97553517, 0.9775739 , 0.97684578, 0.97466142,
0.97495267, 0.97626329, 0.97684578, 0.97568079, 0.97597204,
0.97655454, 0.97582642, 0.97538954, 0.97728266, 0.97728266,
0.97713703, 0.97975826, 0.9812145 , 0.97946702, 0.98019514,
0.9812145 , 0.9787389 , 0.97786515, 0.9787389 , 0.97742828,
0.97670016, 0.97684578, 0.97524392, 0.97728266, 0.97801078,
0.97917577, 0.97728266, 0.97568079, 0.97786515, 0.9781564 ,
0.98077763, 0.98004951, 0.9787389 , 0.98310762, 0.981797 ,
0.9812145 , 0.9702927 , 0.96650648, 0.97466142, 0.97276831,
0.97553517, 0.97466142, 0.97335081, 0.97291394, 0.97437018,
0.9739333 , 0.98092326, 1. , 0.99606815, 0.99286442,
0.9927188 , 0.99111694, 0.99315567, 0.99111694, 0.9884957 ,
0.98470948, 0.98616572, 0.98587447, 0.98616572, 0.98864133,
0.98674822, 0.98951507, 0.99024319, 0.99053444, 0.9890782 ,
0.9890782 , 0.98922382, 0.98747634, 0.9890782 , 0.98820446,
0.98514635, 0.98296199, 0.98281637, 0.98412698, 0.98267074,
0.98456386, 0.98208825, 0.98165138, 0.9708752 , 0.9775739 ,
0.97932139, 0.97859327, 0.98019514, 0.98034076, 0.98106888,
0.98048638, 0.98150575, 0.98194262, 0.98296199, 0.981797 ,
0.98296199, 0.98223387, 0.98267074, 0.9823795 , 0.98310762,
0.98441823, 0.98354449, 0.98296199, 0.98281637, 0.98194262,
0.98339886, 0.98339886, 0.98456386, 0.98412698, 0.98558322,
0.98470948, 0.97888452, 0.96985583, 0.97437018, 0.97728266,
0.9787389 , 0.98703946, 0.98034076, 0.97859327, 0.98048638,
0.97830202, 0.97932139, 0.98980632, 0.99038882, 0.99417504,
0.99082569, 0.9933013 , 0.99199068, 0.9890782 , 0.9921363 ,
0.99097131, 0.99009757, 0.98936945, 0.99009757, 0.98878695,
0.98631134, 0.98660259, 0.99038882, 0.98660259, 0.98354449,
0.97990389, 0.97888452, 0.97932139, 0.97422455, 0.97713703,
0.97771953, 0.97655454, 0.97684578, 0.97859327, 0.97204019,
0.97072958, 0.97174894, 0.97320518, 0.97335081, 0.97305956,
0.97160332, 0.97276831, 0.97276831, 0.9714577 , 0.97160332,
0.97262269, 0.96985583, 0.96752585, 0.96854522, 0.96927334,
0.96985583, 0.96956458, 0.97102082, 0.97072958, 0.97378768,
0.97131207, 0.97058395, 0.97043833, 0.97335081, 0.97422455,
0.97291394, 0.97349643, 0.97320518, 0.97218582, 0.9702927 ,
0.96927334, 0.96636086, 0.96767147, 0.96927334, 0.96636086,
0.96403087, 0.96126402, 0.96272026, 0.96577836, 0.96636086,
0.97189457, 0.97014708, 0.96563274, 0.96941896, 0.96912771,
0.97116645, 0.96912771, 0.97378768, 0.97699141, 0.97713703,
0.97568079, 0.97480705, 0.97335081, 0.97131207, 0.97116645,
0.97072958, 0.97189457, 0.97058395, 0.97102082, 0.96796272,
0.96592398, 0.96898209, 0.96548711, 0.96839959, 0.96738022,
0.96796272, 0.96912771, 0.96694335, 0.97102082, 0.96985583,
0.96636086, 0.9672346 , 0.9672346 , 0.96548711, 0.9672346 ,
0.96577836, 0.96825397, 0.96694335, 0.96752585, 0.96825397,
0.96854522, 0.97058395, 0.97058395, 0.96927334, 0.97072958,
0.96767147, 0.97276831, 0.97116645, 0.97000146, 0.97014708,
0.96985583, 0.96373962, 0.96839959, 0.96694335, 0.96796272,
0.96738022, 0.96752585, 0.96898209, 0.96927334, 0.9714577 ,
0.97174894, 0.96883646, 0.96839959, 0.96854522, 0.97131207,
0.97000146, 0.97043833, 0.96854522, 0.97174894, 0.97189457,
0.97204019, 0.97189457, 0.97204019, 0.97233144, 0.97160332,
0.97932139, 0.98470948, 0.98019514, 0.98427261, 0.9812145 ,
0.98412698, 0.9884957 , 0.99097131, 0.9927188 , 0.98936945,
0.98835008, 0.98354449, 0.98339886, 0.98412698, 0.98354449,
0.98267074, 0.98019514, 0.97233144, 0.97626329, 0.97407893,
0.97247706, 0.96985583, 0.96548711, 0.96898209, 0.96621523,
0.96941896, 0.97174894, 0.96883646, 0.96883646, 0.9702927 ,
0.96941896, 0.97058395, 0.96869084, 0.96796272, 0.96577836,
0.96082714, 0.96286588, 0.96140964, 0.96155526, 0.96213776,
0.96286588, 0.96592398, 0.96694335, 0.96286588, 0.96432212,
0.96592398, 0.96199214, 0.96272026, 0.96344838, 0.96446774,
0.96213776, 0.96315713, 0.96140964, 0.96257463, 0.96519586,
0.96257463, 0.95587593, 0.95412844, 0.95383719, 0.95136158,
0.94611912, 0.95383719, 0.95529343, 0.95733217, 0.95820591,
0.96388525, 0.96606961, 0.96199214, 0.96446774, 0.97174894,
0.97262269, 0.97364206, 0.97305956, 0.97538954, 0.97233144,
0.97917577, 0.97742828, 0.97276831, 0.96883646, 0.97014708,
0.97670016, 0.97713703, 0.9781564 , 0.97495267, 0.96941896,
0.97160332, 0.97072958, 0.96941896, 0.96985583, 0.96898209,
0.97014708, 0.97160332, 0.97072958, 0.97072958, 0.97072958,
0.9702927 , 0.97102082, 0.96956458, 0.97043833, 0.97043833]
tensor = tf.convert_to_tensor(a,dtype=tf.float32,)
tensor