How do I fix "CppException: forward() Expected a value of type 'Tensor'" in PyTorch Android when the model is loaded with Module.load('')?
android log:
mel, melInputTensor = org.pytorch.Tensor$Tensor_float32, [1, 840]
Caused by: com.facebook.jni.CppException: forward() Expected a value of type 'Tensor' for argument 'x'
but instead found type 'Dynamic<128>[Dynamic<1>,]'.
Position: 1
Declaration: forward(__torch__.models.preprocess.___torch_mangle_11.AugmentMelSTFT self, Tensor x) -> Tensor
Exception raised from checkArg at /Users/huydo/Storage/mine/pytorch/aten/src/ATen/core/function_schema_inl.h:340 (most recent call first):
(no backtrace available)
at org.pytorch.NativePeer.forward(Native Method)
at org.pytorch.Module.forward(
android invoke:
// Build a dummy waveform tensor of shape [wavBatch, wavLength] and feed it to the mel model.
val wavBatch = 1
val wavLength = 840
val dummyInput = dummyInput(wavBatch, wavLength, 0.0f)
val inputShape = longArrayOf(wavBatch.toLong(), wavLength.toLong())
melInputTensor = Tensor.fromBlob(dummyInput, inputShape)
if (DEBUG) Log.i(TAG,
    "mel, melInputTensor = " + melInputTensor?.javaClass?.name + ", " + melInputTensor?.shape().contentToString()
)
// This call fails with:
// Caused by: com.facebook.jni.CppException: forward() Expected a value of type 'Tensor' for argument 'x'
// but instead found type 'Dynamic<128>[Dynamic<1>,]'.
// IValue.listFrom wraps the tensor in a list, which does not match the `Tensor x` parameter
// in the declaration reported by the exception.
val forward = melModel!!.forward(IValue.listFrom(melInputTensor))
// The alternative below was also tried and produced:
// Caused by: java.lang.IllegalStateException: Expected IValue type Tuple, actual type Tensor
//val forward = melModel!!.forward(IValue.from(melInputTensor))
However, the same model works fine in Python when loaded with torch.load(''):
python log:
inputs 0 134400 [-1.4953613e-03 -1.6479492e-03 -1.4648438e-03 ...
inputs 1 torch.Size([1, 134400]) tensor([[-1.4954e-03, -1.6479e-03, -1.4648e-03, . <class 'torch.Tensor'>
inputs 2 torch.Size([1, 128, 420]) tensor([[[-0.7647, -0.5746, -0.6255, ..., -1.3985
outputs 0 torch.Size([1, 527]) tensor([[ -3.1250, -6.5625, -7.0312, -7.7500,
python invoke:
# Model that preprocesses a waveform into a mel spectrogram.
mel = load_model_from_uri(mel_ptmobile_name)
(waveform, _) = librosa.core.load(audio_path, sr=sample_rate, mono=True)
if DEBUG:
    print('inputs 0', len(waveform), str(waveform)[:50])
# Add a leading batch dimension: (samples,) -> (1, samples).
waveform = torch.from_numpy(waveform[None, :]).to(device)
if DEBUG:
    print('inputs 1', waveform.shape, str(waveform)[:50], type(waveform))
# our models are trained in half precision mode (torch.float16)
# run on cuda with torch.float16 to get the best performance
# running on cpu with torch.float32 gives similar performance, using torch.bfloat16 is worse
# Autocast is only entered when `cuda` is truthy; otherwise a no-op context is used.
with torch.no_grad(), autocast(device_type=device.type) if cuda else nullcontext():
    spec = mel(waveform)
if DEBUG:
    print('inputs 2', spec.shape, str(spec)[:50])
python model:
class AugmentMelSTFT(nn.Module):
    """Convert raw waveforms into normalized log-mel spectrograms.

    In training mode the mel filterbank edges (fmin/fmax) are randomly
    jittered and frequency/time masking is applied. In eval mode the fixed
    ``fmin``/``fmax`` are used and no masking is performed.

    Args:
        n_mels: Number of mel bins.
        sr: Sample rate of the input audio in Hz.
        win_length: STFT window length in samples.
        hopsize: STFT hop length in samples.
        n_fft: FFT size.
        freqm: Maximum frequency-mask width; ``0`` disables frequency masking.
        timem: Maximum time-mask width; ``0`` disables time masking.
        fmin: Lower edge of the mel filterbank in Hz.
        fmax: Upper edge of the mel filterbank in Hz. When ``None`` it is set
            to ``sr // 2 - fmax_aug_range // 2``.
        fmin_aug_range: Range of the random fmin jitter; ``1`` means none.
        fmax_aug_range: Range of the random fmax jitter; ``1`` means none.
    """

    def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48, timem=192,
                 fmin=0.0, fmax=None, fmin_aug_range=10, fmax_aug_range=2000):
        # Required before buffers or submodules can be registered.
        super().__init__()
        # adapted from: (reference link missing in the original snippet)
        self.win_length = win_length
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.sr = sr
        self.fmin = fmin
        if fmax is None:
            fmax = sr // 2 - fmax_aug_range // 2
            if DEBUG:
                print(f"Warning: FMAX is None setting to {fmax} ")
        self.fmax = fmax
        self.hopsize = hopsize
        # forward() reads self.window; registering it as a buffer lets it follow
        # .to(device). Non-persistent to mirror preemphasis_coefficient below
        # (assumption -- confirm against the checkpoints being loaded).
        self.register_buffer("window", torch.hann_window(win_length, periodic=False), persistent=False)
        assert fmin_aug_range >= 1, f"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation"
        assert fmax_aug_range >= 1, f"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation"
        self.fmin_aug_range = fmin_aug_range
        self.fmax_aug_range = fmax_aug_range

        # Two-tap filter used by forward() as a conv1d kernel: y[t] = x[t + 1] - 0.97 * x[t].
        self.register_buffer("preemphasis_coefficient", torch.as_tensor([[[-.97, 1]]]), persistent=False)
        if freqm == 0:
            self.freqm = torch.nn.Identity()
        else:
            self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
        if timem == 0:
            self.timem = torch.nn.Identity()
        else:
            self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=True)

    def forward(self, x):
        """Compute the normalized log-mel spectrogram of ``x``.

        Args:
            x: Waveform tensor; must be 2-D ``(batch, samples)`` because it is
                unsqueezed to ``(batch, 1, samples)`` for ``conv1d``.

        Returns:
            Tensor of shape ``(batch, n_mels, frames)``.
        """
        if onnx_conf.DEBUG:
            print('mel.forward,', x.shape, x[0][0].dtype, type(x))
        # Pre-emphasis (shortens the signal by one sample).
        x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient).squeeze(1)
        x = torch.stft(x, self.n_fft, hop_length=self.hopsize, win_length=self.win_length,
                       center=True, normalized=False, window=self.window, return_complex=False)
        x = (x ** 2).sum(dim=-1)  # power mag

        fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()
        fmax = self.fmax + self.fmax_aug_range // 2 - torch.randint(self.fmax_aug_range, (1,)).item()
        # don't augment eval data
        if not self.training:
            fmin = self.fmin
            fmax = self.fmax

        mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(self.n_mels, self.n_fft, self.sr,
                                                                 fmin, fmax, vtln_low=100.0, vtln_high=-500.,
                                                                 vtln_warp_factor=1.0)
        # Append one zero column so the filterbank width matches the number of
        # STFT frequency bins required by the matmul below.
        mel_basis = torch.as_tensor(torch.nn.functional.pad(mel_basis, (0, 1), mode='constant', value=0),
                                    device=x.device)
        # Keep the filterbank projection out of autocast's reduced precision.
        with torch.cuda.amp.autocast(enabled=False):
            melspec = torch.matmul(mel_basis, x)

        melspec = (melspec + 0.00001).log()

        # Masking is an augmentation as well, so it is skipped in eval mode.
        if self.training:
            melspec = self.freqm(melspec)
            melspec = self.timem(melspec)

        melspec = (melspec + 4.5) / 5.  # fast normalization
        return melspec
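Solved by changing the Android invocation: use the same waveform length as in the Python run, and pass the tensor itself with IValue.from(...) instead of wrapping it in a list with IValue.listFrom(...), since forward() is declared to take a single Tensor: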
val wavLength = 134400
val forward = melModel!!.forward(IValue.from(melInputTensor))