Issue
I have to design a neural network that takes two inputs, X_1 and X_2. The layer transforms them into fixed-size (10-D) vectors and then sums them in the following manner:
class my_lyr(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def call(self, X_1, X_2):
        return X_1 @ self.w1 + X_2 @ self.w2
However, I need to know the input shapes of X_1 and X_2 before I initialize w1 and w2. I'm not sure how I can declare w2 in build:
def build(self, input_shape):
    self.w1 = self.add_weight('w1', shape=[input_shape[-1], 10])
    # self.w2 = ?????
I want to know how build methods are usually written in such cases.
Solution
If your layer takes two inputs, you can initialize the weights as follows. When a layer is called on a list of tensors, Keras passes a list of input shapes to build, so each weight can be sized from the corresponding entry:
import tensorflow as tf
from tensorflow import keras

class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super(Linear, self).__init__()
        self.units = units

    def build(self, input_shape):
        # input_shape is a list of two shapes, one per input tensor.
        self.wa = self.add_weight(
            shape=(input_shape[0][-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.wb = self.add_weight(
            shape=(input_shape[1][-1], self.units),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        return tf.matmul(inputs[0], self.wa) + tf.matmul(inputs[1], self.wb)
Passing inputs
x = tf.random.normal(shape=(2,2))
linear_layer = Linear(32)
linear_layer([x, x])
<tf.Tensor: shape=(2, 32), dtype=float32, numpy=
array([[-0.08829461, -0.01605312, -0.04368614, -0.08116315, -0.01521384,
0.01132785, 0.10704445, -0.10873697, -0.0525714 , 0.07684848,
0.04586978, 0.01315852, 0.01369547, 0.07404792, 0.10313608,
-0.10851607, 0.04091477, -0.01723676, -0.0326797 , 0.03598418,
-0.11335816, -0.10044714, 0.13555384, 0.01689356, 0.02631954,
0.08226107, -0.08765724, -0.05981663, 0.00531629, 0.02930426,
0.04155847, 0.05339598],
[ 0.20617458, -0.05936547, 0.01735754, -0.06575315, 0.10090968,
-0.07796012, -0.1956767 , -0.03406558, 0.18604615, -0.03547171,
0.02784208, 0.0471364 , -0.10712875, -0.07869454, -0.19457275,
0.13593757, -0.14659101, 0.0384632 , 0.02344182, -0.03861775,
0.08948556, 0.09225713, -0.17395493, 0.10021958, -0.09210777,
-0.09865301, 0.2536609 , -0.02547608, 0.02885125, -0.01271547,
-0.10340843, -0.0338558 ]], dtype=float32)>
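The example above passes the same tensor twice, so it doesn't show that each weight is sized independently. Here is a minimal sketch of my own (not from the original answer; the shapes are illustrative) that feeds two inputs with different feature dimensions to the same Linear layer:

import tensorflow as tf

# Two inputs with different feature dimensions: 5 and 3.
x1 = tf.random.normal(shape=(4, 5))
x2 = tf.random.normal(shape=(4, 3))

layer = Linear(10)      # 10-D output, as in the question
out = layer([x1, x2])   # triggers build with a list of two shapes

print(layer.wa.shape)   # (5, 10) -- built from x1's last dimension
print(layer.wb.shape)   # (3, 10) -- built from x2's last dimension
print(out.shape)        # (4, 10)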
Answered By - M.Innat