Issue
In this example, we see the following implementation of nn.Module:
import torch
from torch_geometric.nn import GCNConv

class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()
However, the docs state that forward(*input) "Should be overridden by all subclasses."
Why is that not the case in this example?
Solution
This Net module is meant to be used via two separate interfaces, encode and decode, or at least it seems so. Since it doesn't provide a forward implementation, then yes, it is improperly inheriting from nn.Module. However, the code is still "valid" and will run properly, though it may have side effects if you are using forward hooks.
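As a hedged sketch of how such a module is typically driven, the snippet below calls encode and decode directly. Since GCNConv requires torch_geometric, nn.Linear stands in for it here (an assumption made only so the example runs with plain PyTorch); the tensor shapes and the decode logic are unchanged.

```python
import torch

class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # nn.Linear stands in for GCNConv so this sketch needs only torch
        self.conv1 = torch.nn.Linear(in_channels, hidden_channels)
        self.conv2 = torch.nn.Linear(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        # a real GCNConv would also consume edge_index here
        x = self.conv1(x).relu()
        return self.conv2(x)

    def decode(self, z, edge_label_index):
        # dot product of the two endpoint embeddings of each candidate edge
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

x = torch.randn(5, 8)                               # 5 nodes, 8 features each
edge_index = torch.tensor([[0, 1], [1, 2]])         # dummy graph connectivity
edge_label_index = torch.tensor([[0, 3], [1, 4]])   # candidate edges (0,1) and (3,4)

model = Net(8, 16, 4)
z = model.encode(x, edge_index)             # node embeddings, shape (5, 4)
scores = model.decode(z, edge_label_index)  # one score per candidate edge, shape (2,)
```

Note that neither call goes through model(...), which is exactly the situation the question is about.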
The standard way of performing inference with an nn.Module is to call the object itself, i.e. to invoke its __call__ function. This __call__ function is implemented by the parent class nn.Module and will in turn do two things:
- handle forward hooks before and after the inference call
- call the forward function of the class
The __call__ function acts as a wrapper around forward.
So for this reason, the forward function is expected to be overridden by the user-defined nn.Module. The only caveat of violating this design pattern is that any hooks registered on the nn.Module will effectively be ignored, since calling a method like encode or decode directly bypasses __call__.
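The caveat can be made concrete with a simplified, pure-Python mimic of nn.Module's dispatch (an illustration only; the real nn.Module handles many more cases, such as pre-hooks and backward hooks):

```python
class Module:
    """Bare-bones mimic of how nn.Module wires __call__ to forward."""
    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def __call__(self, *args, **kwargs):
        out = self.forward(*args, **kwargs)   # delegate to the subclass's forward
        for hook in self._forward_hooks:      # then run the registered forward hooks
            hook(self, args, out)
        return out

class Net(Module):
    def forward(self, x):
        return x * 2

    def encode(self, x):   # custom entry point, like in the question
        return x + 1

calls = []
net = Net()
net.register_forward_hook(lambda mod, inp, out: calls.append(out))

net(3)         # goes through __call__, so the hook fires
net.encode(3)  # bypasses __call__, so the hook is silently skipped
```

After running this, calls contains only the result of the net(3) invocation: the hook never saw the encode call, which is precisely the side effect described above.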
Answered By - Ivan