I trained a ProGAN agent using this PyTorch reimplementation, and I saved the agent as a .pth file. Now I need to convert the agent into the .onnx format, which I am doing using this script:
from torch.autograd import Variable
import torch.onnx
import torchvision
import torch
device = torch.device("cuda")
dummy_input = torch.randn(1, 3, 64, 64)
state_dict = torch.load("GAN_agent.pth", map_location = device)
torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")
Once I run it, I get the error AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict' (full traceback below). As far as I understand, the problem is that converting the agent into .onnx requires more information. Am I missing something?
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-2-c64481d4eddd> in <module>
10 state_dict = torch.load("GAN_agent.pth", map_location = device)
11
---> 12 torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
146 operator_export_type, opset_version, _retain_param_name,
147 do_constant_folding, example_outputs,
--> 148 strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
149
150
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
64 _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
65 example_outputs=example_outputs, strip_doc_string=strip_doc_string,
---> 66 dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs)
67
68
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size)
414 example_outputs, propagate,
415 _retain_param_name, do_constant_folding,
--> 416 fixed_batch_size=fixed_batch_size)
417
418 # TODO: Don't allocate a in-memory string for the protobuf
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _model_to_graph(model, args, verbose, training, input_names, output_names, operator_export_type, example_outputs, propagate, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size)
277 model.graph, tuple(in_vars), False, propagate)
278 else:
--> 279 graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
280 state_dict = _unique_state_dict(model)
281 params = list(state_dict.values())
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _trace_and_get_graph_from_model(model, args, training)
226 # A basic sanity check: make sure the state_dict keys are the same
227 # before and after running the model. Fail fast!
--> 228 orig_state_dict_keys = _unique_state_dict(model).keys()
229
230 # By default, training=False, which is good because running a model in
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\jit\__init__.py in _unique_state_dict(module, keep_vars)
283 # id(v) doesn't work with it. So we always get the Parameter or Buffer
284 # as values, and deduplicate the params using Parameters and Buffers
--> 285 state_dict = module.state_dict(keep_vars=True)
286 filtered_dict = type(state_dict)()
287 seen_ids = set()
AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'
The files you have are state_dicts, which are simply mappings from layer names to tensors (weights, biases and the like; see here for a more thorough introduction).
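To make this concrete, here is a minimal sketch using a small stand-in model (TinyNet is hypothetical and not part of the ProGAN repository). It shows that what torch.load returns for such a file is a plain ordered mapping with no forward method, which is why the exporter cannot use it directly.

```python
import io
from collections import OrderedDict

import torch


class TinyNet(torch.nn.Module):
    """Hypothetical stand-in for a real model definition."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# Save only the state_dict, the same kind of file as in the question.
buffer = io.BytesIO()  # in-memory file, so nothing is written to disk
source = TinyNet()
torch.save(source.state_dict(), buffer)
buffer.seek(0)

loaded = torch.load(buffer, map_location="cpu")
is_mapping = isinstance(loaded, OrderedDict)  # a mapping, not a module
has_forward = hasattr(loaded, "forward")      # nothing to trace or export
layer_names = list(loaded.keys())             # names such as "fc.weight"

# The weights only become usable once they are mapped onto a model instance.
restored = TinyNet()
restored.load_state_dict(loaded)
```

Only an object like restored, which has a forward method, is something torch.onnx.export can work with.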
What that means is that you need a model onto which those saved weights and biases can be mapped, but first things first:
Clone the repository where the model definitions are located and open the file /pro_gan_pytorch/pro_gan_pytorch/PRO_GAN.py. It needs some modifications in order to work with onnx. The onnx exporter requires inputs to be passed as torch.tensor only (or a list/dict of those), while the Generator class needs int and float arguments.
A simple solution is to slightly modify the forward function (line 80 in the file, you can verify it on GitHub) to the following:
def forward(self, x, depth, alpha):
    """
    forward pass of the Generator
    :param x: input noise
    :param depth: current depth from where output is required
    :param alpha: value of alpha for fade-in effect
    :return: y => output
    """
    # THOSE TWO LINES WERE ADDED
    # We will pass tensors but unpack them here to `int` and `float`
    depth = depth.item()
    alpha = alpha.item()
    # THOSE TWO LINES WERE ADDED

    assert depth < self.depth, "Requested output depth cannot be produced"

    y = self.initial_block(x)

    if depth > 0:
        for block in self.layers[: depth - 1]:
            y = block(y)

        residual = self.rgb_converters[depth - 1](self.temporaryUpsampler(y))
        straight = self.rgb_converters[depth](self.layers[depth - 1](y))

        out = (alpha * straight) + ((1 - alpha) * residual)
    else:
        out = self.rgb_converters[0](y)

    return out
Only the unpacking via item() was added here. Every input that is not of Tensor type should be passed in as a tensor and unpacked ASAP at the top of your function. This will not break your saved checkpoint, so no worries, as the checkpoint is just a layer-to-weight mapping.
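The pack-and-unpack pattern can be sketched on a much smaller module. FadeIn below is a hypothetical example, not code from the repository; it only mirrors the alpha blending step of forward. Keep in mind that, as far as I know, a value obtained through item() is recorded as a constant when the model is traced for export, so the exported graph is tied to the depth and alpha you pass as dummy inputs.

```python
import torch


class FadeIn(torch.nn.Module):
    """Hypothetical module that blends two tensors with a scalar alpha."""

    def forward(self, straight, residual, alpha):
        # alpha arrives as a one-element tensor; unpack it to a Python float
        alpha = alpha.item()
        return (alpha * straight) + ((1 - alpha) * residual)


blend = FadeIn()
straight = torch.ones(1, 3)
residual = torch.zeros(1, 3)

# The scalar is packed into a tensor at the call site
out = blend(straight, residual, torch.tensor([0.25]))
```

The caller only ever hands tensors to forward, which is what the exporter expects, while the body of the function keeps working with plain Python numbers.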
Place this script in /pro_gan_pytorch (where README.md is located as well):
import torch
from pro_gan_pytorch import PRO_GAN as pg
gen = torch.nn.DataParallel(pg.Generator(depth=9))
gen.load_state_dict(torch.load("GAN_GEN_SHADOW_8.pth"))
module = gen.module.to("cpu")
# Arguments like depth and alpha may need to be changed
dummy_inputs = (torch.randn(1, 512), torch.tensor([5]), torch.tensor([0.1]))
torch.onnx.export(module, dummy_inputs, "GAN_GEN8.onnx", verbose=True)
Please notice a few things:
- The model has to be created before loading the weights, as the file is a state_dict only.
- torch.nn.DataParallel is needed as that's what the model was trained on (not sure about your case, please adjust accordingly). After loading we can get the module itself via the module attribute.
- The module is moved to CPU, no need for GPU here I think. You could move everything to GPU if you insist, though.
- The dummy input to the generator is noise with 512 elements.
Run it and your .onnx file should be there.
Oh, and as you are after a different checkpoint, you may want to follow a similar procedure, though there are no guarantees everything will work fine (it does look like it, though).
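Regarding the DataParallel point, here is a minimal sketch of why the wrapper matters, using a torch.nn.Linear layer as a hypothetical stand-in for the Generator. A state_dict saved from a wrapped model has keys prefixed with "module.", so it loads either into a wrapped model (as in the script above) or, assuming you strip the prefix yourself, into the bare model.

```python
import torch

net = torch.nn.Linear(4, 2)
wrapped = torch.nn.DataParallel(net)

# Keys saved from the wrapper carry a "module." prefix
wrapped_keys = list(wrapped.state_dict().keys())

# Alternative to wrapping at load time: strip the prefix by hand
prefix = "module."
stripped = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in wrapped.state_dict().items()
}

plain = torch.nn.Linear(4, 2)
plain.load_state_dict(stripped)
```

If your own checkpoint was saved without DataParallel, its keys will not have the prefix and you can load it into the bare model directly.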