Productising TensorFlow with `swish_f32` errors
Problem
Recently I was trying to freeze and optimise a model for deployment,
but I kept running into the following error:
Op type not registered 'swish_f32' in binary running
Weird. I was pretty sure I was only using tf.nn.swish in my model, so why was it causing
issues with other TensorFlow tooling, such as convert_variables_to_constants and TransformGraph?
It turns out this is because TF’s implementation of swish is optimised, so it is
stored in the graph definition’s library of functions, and the node in the graph
refers to it by the op name swish_f32.
I believe this is fixed in newer TensorFlow versions, but I’m using TF 1.10
and other teams expect to use my model with TF 1.10.
So I can’t just upgrade. Instead, we will rewrite the graph to replace nodes with:
op: "swish_f32"
with two operation nodes to represent:
x * sigmoid(x)
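As a quick sanity check on the identity the rewrite relies on, here is a minimal pure-Python sketch of swish written as x * sigmoid(x). The helper names `sigmoid` and `swish` are illustrative only and are not TensorFlow APIs.

```python
import math


def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)); fine for moderate x, not overflow-safe."""
    return 1.0 / (1.0 + math.exp(-x))


def swish(x):
    """Swish expressed as the two operations the rewritten graph uses."""
    # One Sigmoid op followed by one Mul op, mirroring the replacement nodes.
    return x * sigmoid(x)
```

For large positive x the sigmoid term approaches 1, so swish(x) approaches x. For large negative x it approaches 0.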
Solution
Below we offer a solution which takes a graph definition and rewrites it to
replace any swish_f32 operation with two nodes that represent x * sigmoid(x).
The solution code could be simplified if you never pass a custom name to tf.nn.swish.
However, we will presume that you have or could have, or, if this is going into a pipeline, that another researcher might have.
To handle that, we reuse the original node names as prefixes for the new nodes and then overwrite the input names in a second pass.
In our pipeline we check predictions before and after optimising models.
I suggest you do the same to make sure you see the same results as we do.
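One way to run such a check is sketched here. It assumes you have already collected the model outputs before and after the rewrite as flat lists of floats. The function name `predictions_match` and the tolerance values are hypothetical choices for illustration, not the exact check we run.

```python
import math


def predictions_match(before, after, rel_tol=1e-5, abs_tol=1e-6):
    """Return True if both prediction lists agree element-wise within tolerance."""
    # A length mismatch means the two models did not produce comparable outputs.
    if len(before) != len(after):
        return False
    return all(
        math.isclose(b, a, rel_tol=rel_tol, abs_tol=abs_tol)
        for b, a in zip(before, after)
    )
```

A tolerance is used rather than exact equality because small floating point differences between two graphs are possible even when they compute the same function.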
import tensorflow as tf


def replace_swish_f32(graph_def):
    """Replace all swish_f32 nodes, then use their names to repoint other nodes' inputs.

    Returns:
        Transformed GraphDef.
    """
    swish32_node_names = set()
    modified_graph_def = tf.GraphDef()
    new_nodes = []
    # First pass: swap each swish_f32 node for a Sigmoid node and a Mul node.
    for node in graph_def.node:
        if node.op == "swish_f32":
            swish32_node_names.add(node.name)
            new_nodes.extend(get_swish_nodes(node.input[0], node.name))
        else:
            new_nodes.append(node)
    modified_graph_def.node.extend(new_nodes)
    # Second pass: replace references to deleted nodes with created nodes.
    for node in modified_graph_def.node:
        inp_names = []
        for inp in node.input:
            if inp in swish32_node_names:
                inp_names.append(inp + "_Mul")
            else:
                inp_names.append(inp)
        del node.input[:]
        node.input.extend(inp_names)
    return modified_graph_def


def get_swish_nodes(input_node_name, name):
    """Build the Sigmoid and Mul nodes that together compute x * sigmoid(x)."""
    sigmoid_node = tf.NodeDef(
        name=name + "_Sigmoid",
        op="Sigmoid",
        attr={"T": tf.AttrValue(type=tf.float32.as_datatype_enum)},
    )
    sigmoid_node.input.append(input_node_name)
    mul_node = tf.NodeDef(
        name=name + "_Mul",
        op="Mul",
        attr={"T": tf.AttrValue(type=tf.float32.as_datatype_enum)},
    )
    # The Mul node multiplies the original input by its sigmoid.
    mul_node.input.extend([input_node_name, sigmoid_node.name])
    return [sigmoid_node, mul_node]
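
The second pass above compares each input string to the node name exactly. In a GraphDef an input may also be written with an output port (`name:0`) or as a control dependency (`^name`). If your graph contains such references, a helper along these lines might be used in place of the exact comparison. The name `rewrite_input` is hypothetical, and the sketch is plain Python so it can be tried without TensorFlow.

```python
def rewrite_input(inp, swish_node_names, suffix="_Mul"):
    """Repoint one input string at the Mul node if it refers to a removed swish node."""
    # Control dependencies are written as "^node_name".
    prefix = "^" if inp.startswith("^") else ""
    body = inp[len(prefix):]
    # Data inputs may carry an output port, written as "node_name:port".
    node_name, sep, port = body.partition(":")
    if node_name in swish_node_names:
        node_name += suffix
    return prefix + node_name + sep + port
```

For example, with `{"act"}` as the set of removed names, `"act:0"` would become `"act_Mul:0"`, while an unrelated name such as `"act_1"` is left untouched.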