Productising TensorFlow with `swish_f32` errors

Problem

Recently I was trying to freeze and optimise a model for deployment, but I kept running into the following error:
Op type not registered 'swish_f32' in binary running

Weird. I was pretty sure I was only using tf.nn.swish in my model, so why was it causing problems with other TensorFlow tooling such as convert_variables_to_constants and TransformGraph?

It turns out that TF's implementation of swish is optimised, so it is stored as a function (swish_f32) in the graph definition's library of functions. The node in the graph therefore has the op swish_f32, which the other tools do not find among their registered ops.
I believe this is fixed in newer TensorFlow versions, but I'm using TF 1.10 and other teams expect to use my model with TF 1.10.
So I can't just upgrade. Instead, we will rewrite the graph, replacing each node with:
op: "swish_f32"
with two operation nodes that represent:
x * sigmoid(x)
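To confirm that this is what is happening in your own graph, one option is to list the nodes whose op names a function in the graph's library. The sketch below is a minimal illustration: `nodes_calling_library_functions` is a hypothetical helper name, and the toy graph is a plain-Python stand-in that only mimics the `node`, `library.function` and `signature.name` fields of a real GraphDef.

```python
from types import SimpleNamespace


def nodes_calling_library_functions(graph_def):
    """Return (node name, op) pairs for nodes whose op is a library function."""
    # Each function in the library carries its name in signature.name.
    function_names = {fn.signature.name for fn in graph_def.library.function}
    return [(node.name, node.op) for node in graph_def.node if node.op in function_names]


# Stand-in for a GraphDef (assumption: a real one exposes the same fields).
toy_graph = SimpleNamespace(
    library=SimpleNamespace(
        function=[SimpleNamespace(signature=SimpleNamespace(name="swish_f32"))]
    ),
    node=[
        SimpleNamespace(name="input", op="Placeholder"),
        SimpleNamespace(name="dense/act", op="swish_f32"),
    ],
)

found = nodes_calling_library_functions(toy_graph)
print(found)
```

With a real model you would pass the loaded GraphDef instead of `toy_graph`; any node reported here is a candidate for the rewrite described in this post.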

Solution

Below is a solution that takes a graph definition and rewrites it, replacing each swish_f32 operation with two nodes that represent x * sigmoid(x).

We could simplify the solution code if you haven't passed a custom name to the tf.nn.swish call. However, we will presume that you have, or could have, or that another researcher might have if the model is going into a pipeline.

For that reason, the new nodes re-use the original node's name as a prefix, and a second pass overwrites the input names of any nodes that consumed the original.

In our pipeline we check predictions before and after optimising models.
I suggest you do the same to make sure you see the same results as we do.
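One way to make that check concrete is to compare the two sets of outputs element by element within a tolerance. The sketch below is a minimal, standard-library-only version. `predictions_match` and its default tolerances are illustrative choices rather than part of our pipeline, and it assumes you have already run the same inputs through both the original and the rewritten graph.

```python
import math


def _flatten(values):
    """Yield scalars from nested lists/tuples; objects with .tolist() are converted first."""
    if hasattr(values, "tolist"):
        values = values.tolist()
    if isinstance(values, (list, tuple)):
        for item in values:
            yield from _flatten(item)
    else:
        yield float(values)


def predictions_match(before, after, rel_tol=1e-5, abs_tol=1e-6):
    """Return True if both prediction sets have the same size and close values."""
    flat_before = list(_flatten(before))
    flat_after = list(_flatten(after))
    if len(flat_before) != len(flat_after):
        return False
    return all(
        math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
        for a, b in zip(flat_before, flat_after)
    )


# Hypothetical outputs from the original and the rewritten graph.
same = predictions_match([[0.1, 0.9]], [[0.1, 0.9000001]])
different = predictions_match([[0.1, 0.9]], [[0.2, 0.8]])
print(same, different)
```

The tolerances are a judgment call: the rewritten graph computes the same function with different ops, so small floating-point differences are plausible while large ones point to a wiring mistake.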

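import tensorflow as tf  # written against TF 1.x (tf.GraphDef, tf.NodeDef, tf.AttrValue)


def replace_swish_f32(graph_def):
   """Replace every swish_f32 node with Sigmoid and Mul nodes, then rewire its consumers.

   Args:
       graph_def: GraphDef that may contain nodes with op "swish_f32".

   Returns:
       Transformed GraphDef.
   """
   swish32_node_names = set()
   # Only nodes are carried over; graph_def.library and graph_def.versions are not copied.
   modified_graph_def = tf.GraphDef()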
   new_nodes = []
   for node in graph_def.node:
       if "swish_f32" == node.op:
           swish32_node_names.add(node.name)
           new_nodes.extend(get_swish_nodes(node.input[0], node.name))
       else:
           new_nodes.append(node)

   modified_graph_def.node.extend(new_nodes)

   # Replace references to deleted nodes with created nodes.
   for node in modified_graph_def.node:
       inp_names = []
       for inp in node.input:
           if inp in swish32_node_names:
               inp_names.append(inp + "_Mul")
           else:
               inp_names.append(inp)

       del node.input[:]
       node.input.extend(inp_names)

   return modified_graph_def


def get_swish_nodes(input_node_name, name):
   """Build the Sigmoid and Mul nodes that compute x * sigmoid(x) for one swish_f32 node."""
   float32_attr = {"T": tf.AttrValue(type=tf.float32.as_datatype_enum)}
   # sigmoid(x)
   sigmoid_node = tf.NodeDef(name=name + "_Sigmoid", op="Sigmoid", attr=float32_attr)
   sigmoid_node.input.append(input_node_name)
   # x * sigmoid(x)
   mul_node = tf.NodeDef(name=name + "_Mul", op="Mul", attr=float32_attr)
   mul_node.input.extend([input_node_name, sigmoid_node.name])
   return [sigmoid_node, mul_node]
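The second pass above compares each input string with the bare node name. GraphDef inputs can also carry an output index (`name:0`) or a control-dependency prefix (`^name`), and those forms would not match. If your graph contains them, the rewrite could be extended along these lines. This is a sketch under that assumption, and `rewrite_input` is a hypothetical helper rather than part of the code above.

```python
def rewrite_input(inp, swish_node_names, suffix="_Mul"):
    """Point an input string at the replacement Mul node if it refers to a replaced node."""
    # Control dependencies are written as "^node_name".
    prefix = "^" if inp.startswith("^") else ""
    body = inp[len(prefix):]
    # Data inputs may be written as "node_name:output_index".
    name, sep, port = body.partition(":")
    if name not in swish_node_names:
        return inp
    return prefix + name + suffix + sep + port


replaced = {"dense/act"}
rewritten = [rewrite_input(i, replaced) for i in ["dense/act", "dense/act:0", "^dense/act", "dense/bias"]]
print(rewritten)
```

In the second loop of replace_swish_f32 this would take the place of the `if inp in swish32_node_names` branch, with each input passed through the helper before being appended.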
