I'm trying to get all layers from within a custom layer subclassed from tf.keras.layers.Layer, but I'm having difficulty with this. The end goal is to build a DAG (Directed Acyclic Graph) whose nodes are layers from tf.keras.layers.*. Here is an example:
from tensorflow import keras
...

class ResidualBlock(keras.layers.Layer):
    def __init__(
        self,
        filters0: int,
        filters1: int,
        activation: str = "leaky_relu",
        **kwargs,
    ) -> None:
        super().__init__()
        self.conv0 = keras.layers.Conv2D(
            filters0, 1, activation=activation, **kwargs)
        # "same" padding so the residual add below has matching spatial dims
        self.conv1 = keras.layers.Conv2D(
            filters1, 3, activation=activation, padding="same", **kwargs)

    def call(self, inputs, training=False):
        x = self.conv0(inputs, training=training)
        x = self.conv1(x, training=training)
        x = inputs + x  # residual connection
        return x


rb = ResidualBlock(2, 3)
new_model = keras.Sequential([rb, keras.layers.Dense(200)])
convert_to_DAG(new_model)  # this is the function I am trying to write
I want to get something like this:
[{'type': 'ResidualBlock', 'children': ['conv2D_1']},
{'type': 'conv2D_1', 'children': ['conv2D_2', 'residual']},
{'type': 'conv2D_2', 'children': ['Dense_1']},
...
]
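For what it's worth, getting a flat list of the nested sublayers looks doable through tf.Module.submodules (keras layers inherit from tf.Module), but that only gives the layers themselves, not the edges between them, which is what the DAG needs. A rough sketch of what I mean, not a solution:

def list_sublayers(layer: keras.layers.Layer) -> list:
    # tf.Module.submodules walks nested attributes recursively,
    # so it finds conv0 and conv1 inside ResidualBlock.
    return [m for m in layer.submodules if isinstance(m, keras.layers.Layer)]

print([l.name for l in list_sublayers(rb)])  # the two Conv2D sublayers, but no connectivity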
I've looked at related answers such as How to access recursive layers of custom layer in tensorflow keras, but that one accesses layers from a model subclassed from tf.keras.Model, NOT from tf.keras.layers.Layer.
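To illustrate the difference as I understand it (a minimal, hypothetical sketch, not part of my actual code):

class SubModel(keras.Model):  # subclassed from Model, not Layer
    def __init__(self):
        super().__init__()
        self.conv = keras.layers.Conv2D(4, 3)

print(SubModel().layers)           # works: [<Conv2D ...>]
print(ResidualBlock(2, 3).layers)  # fails; a Layer subclass does not expose .layers the way a Model does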
The following code, from Check which are the next layers in a tensorflow keras model, breaks a model apart based on its nodes, but it does not recursively follow each layer down to its base layers/operations (which I require).
def get_layer_summary_with_connections(layer, relevant_nodes):
    info = {}
    connections = []
    # Walk the layer's inbound nodes to find which layers feed into it.
    for node in layer._inbound_nodes:
        if relevant_nodes and node not in relevant_nodes:
            continue  # skip nodes that belong to a different model
        for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
            connections.append(inbound_layer.name)

    name = layer.name
    info['type'] = layer.__class__.__name__
    info['parents'] = connections
    return info
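For reference, this is roughly how I am driving that helper, following the linked answer (relevant_nodes is collected from the model's private _nodes_by_depth attribute, and I assume the model has to be built first so the inbound nodes exist):

# Build the model so the underlying keras graph and its nodes exist.
new_model.build((None, 32, 32, 3))

relevant_nodes = []
for depth_nodes in new_model._nodes_by_depth.values():
    relevant_nodes += depth_nodes

for layer in new_model.layers:
    print(get_layer_summary_with_connections(layer, relevant_nodes))
# This only reports the ResidualBlock and the Dense layer;
# it never descends into conv0/conv1 inside the custom layer.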
The end result should be a DAG that contains all base layers and operations, as in the linked image ("DAG end result"), where every node is a base layer or operation.
Normally I would just inspect model.layers, but that is not possible in this case, since the sublayers live inside a custom Layer subclass.
Thank you for any help. I can clarify if anything is unclear.