import time
import queue

import onnx
import onnxsim
import numpy as np
import onnxruntime as ort
import onnx_graphsurgeon as gs


def _make_matmul(index, batch_size, num_input_feat, num_output_feat):
    """Build one ``input @ weight`` MatMul node with a random float32 weight.

    Args:
        index: Suffix used in the names ``input_{index}``,
            ``matmul_{index}_weight``, ``matmul_{index}_output`` and
            ``MatMul_{index}``.
        batch_size: Leading dimension of the input and the output.
        num_input_feat: Feature dimension of the input.
        num_output_feat: Feature dimension of the output.

    Returns:
        ``(input_var, output_var, node)``.
    """
    input_var = gs.Variable(
        name=f'input_{index}',
        dtype=np.float32,
        shape=(batch_size, num_input_feat),
    )
    # Weights are drawn from NumPy's global RNG (unseeded here), so the
    # saved values differ from run to run.
    weight = gs.Constant(
        name=f'matmul_{index}_weight',
        values=np.random.rand(num_input_feat, num_output_feat).astype(np.float32),
    )
    output_var = gs.Variable(
        name=f'matmul_{index}_output',
        dtype=np.float32,
        shape=(batch_size, num_output_feat),
    )
    node = gs.Node(
        op='MatMul',
        name=f'MatMul_{index}',
        inputs=[input_var, weight],
        outputs=[output_var],
    )
    return input_var, output_var, node


def _export(graph, dst_onnx_path):
    """Export *graph* to ONNX, run shape inference, check it, and save it."""
    model = gs.export_onnx(graph)
    model = onnx.shape_inference.infer_shapes(model, data_prop=True)
    onnx.checker.check_model(model)
    onnx.save(model, dst_onnx_path)


def two_matmul(dst_onnx_path='two_matmul.onnx'):
    """Save a model with two independent MatMuls, both exposed as outputs.

    ``input_1`` (300, 1024) and ``input_2`` (300, 4096) are each multiplied
    by a random weight, producing ``matmul_1_output`` and
    ``matmul_2_output`` of shape (300, 8096).

    Args:
        dst_onnx_path: Destination file for the model.
    """
    # NOTE(review): 8096 may be a typo for 8192 -- confirm intent.
    batch_size, num_output_feat = 300, 8096
    num_input_1_feat, num_input_2_feat = 1024, 4096

    input_1, matmul_1_output, matmul_node_1 = _make_matmul(
        1, batch_size, num_input_1_feat, num_output_feat)
    input_2, matmul_2_output, matmul_node_2 = _make_matmul(
        2, batch_size, num_input_2_feat, num_output_feat)

    graph = gs.Graph(
        nodes=[matmul_node_1, matmul_node_2],
        inputs=[input_1, input_2],
        outputs=[matmul_1_output, matmul_2_output],
        name='two_matmul',
        opset=11,
    )
    _export(graph, dst_onnx_path)


def two_matmul_if(dst_onnx_path='two_matmul_if.onnx'):
    """Save a model that selects one of two MatMuls with an ONNX ``If`` node.

    ``input_flag`` (int32 scalar) is compared with 1. When equal, the
    ``then_branch`` computes ``input_1 @ W1``; otherwise the ``else_branch``
    computes ``input_2 @ W2``. The chosen (300, 8096) result is the single
    graph output ``if_output``.

    Args:
        dst_onnx_path: Destination file for the model.
    """
    # NOTE(review): 8096 may be a typo for 8192 -- confirm intent.
    batch_size, num_output_feat = 300, 8096
    num_input_1_feat, num_input_2_feat = 1024, 4096

    # If branches take no formal inputs: the ONNX If op passes none to its
    # subgraphs. The MatMul nodes inside each branch read input_1 / input_2
    # from the enclosing graph by name (outer-scope capture).
    input_1, matmul_1_output, matmul_node_1 = _make_matmul(
        1, batch_size, num_input_1_feat, num_output_feat)
    subgraph_1 = gs.Graph(
        nodes=[matmul_node_1],
        inputs=[],
        outputs=[matmul_1_output],
        name='subgraph_1',
        opset=13,
    )

    input_2, matmul_2_output, matmul_node_2 = _make_matmul(
        2, batch_size, num_input_2_feat, num_output_feat)
    subgraph_2 = gs.Graph(
        nodes=[matmul_node_2],
        inputs=[],
        outputs=[matmul_2_output],
        name='subgraph_2',
        opset=13,
    )

    # Condition: input_flag == 1 (scalar bool).
    input_flag = gs.Variable(name='input_flag', dtype=np.int32, shape=())
    equal_target = gs.Constant(
        name='Equal_target',
        values=np.array(1, dtype=np.int32),
    )
    equal_output = gs.Variable(name='equal_output', dtype=np.bool_, shape=())
    equal_node = gs.Node(
        op='Equal',
        name='Equal',
        inputs=[input_flag, equal_target],
        outputs=[equal_output],
    )

    if_output = gs.Variable(
        name='if_output',
        dtype=np.float32,
        shape=(batch_size, num_output_feat),
    )
    if_node = gs.Node(
        op='If',
        name='If',
        attrs={'then_branch': subgraph_1, 'else_branch': subgraph_2},
        inputs=[equal_output],
        outputs=[if_output],
    )

    # input_1 / input_2 must stay top-level inputs so the branches can
    # capture them from this scope.
    graph = gs.Graph(
        nodes=[equal_node, if_node],
        inputs=[input_1, input_2, input_flag],
        outputs=[if_output],
        name='two_matmul_if',
        opset=13,
    )
    _export(graph, dst_onnx_path)


if __name__ == '__main__':
    two_matmul()
    two_matmul_if()