Search code examples
Tags: python, android, deep-learning, pytorch, tensor

How to fix `CppException: forward() Expected a value of type 'Tensor'` in PyTorch Android when the same model works fine in Python


How can I fix the `CppException: forward() Expected a value of type 'Tensor'` error raised in PyTorch Android when the model is loaded with `Module.load('mel_ptmobile_v2.pt')`?

Android log:

mel, melInputTensor = org.pytorch.Tensor$Tensor_float32, [1, 840]
Caused by: com.facebook.jni.CppException: forward() Expected a value of type 'Tensor' for argument 'x' 
           but instead found type 'Dynamic<128>[Dynamic<1>,]'.
                 Position: 1
                 Declaration: forward(__torch__.models.preprocess.___torch_mangle_11.AugmentMelSTFT self, Tensor x) -> Tensor
                 Exception raised from checkArg at /Users/huydo/Storage/mine/pytorch/aten/src/ATen/core/function_schema_inl.h:340 (most recent call first):
                 (no backtrace available)
                    at org.pytorch.NativePeer.forward(Native Method)
                    at org.pytorch.Module.forward(Module.java:52)

Android invocation:

// Build a [wavBatch, wavLength] = [1, 840] float input for the mel-spectrogram
// module (values come from dummyInput(..., 0.0f) — presumably all zeros; verify).
// NOTE(review): the Python run feeds a [1, 134400] waveform, and the accepted
// solution also switched wavLength to 134400.
val wavBatch = 1
val wavLength = 840
val dummyInput = dummyInput(wavBatch, wavLength, 0.0f)
val inputShape = longArrayOf(wavBatch.toLong(), wavLength.toLong())
melInputTensor = Tensor.fromBlob(dummyInput, inputShape)
if (DEBUG) Log.i(TAG,
    "mel, melInputTensor = " + melInputTensor?.javaClass?.name + ", " + melInputTensor?.shape().contentToString()
)

// BUG: IValue.listFrom wraps the tensor in a list, but the exported module is
// declared as forward(AugmentMelSTFT self, Tensor x), so the runtime rejects it:
// Caused by: com.facebook.jni.CppException: forward() Expected a value of type 'Tensor' for argument 'x'
// but instead found type 'Dynamic<128>[Dynamic<1>,]'.
val forward = melModel!!.forward(IValue.listFrom(melInputTensor))
// Passing the tensor directly with IValue.from is the call the accepted solution uses.
// NOTE(review): the error below was presumably raised by later code that read the
// result as a tuple; AugmentMelSTFT.forward returns a single Tensor — confirm.
// Caused by: java.lang.IllegalStateException: Expected IValue type Tuple, actual type Tensor
//val forward = melModel!!.forward(IValue.from(melInputTensor))

However, the same model works fine in Python when loaded with `torch.load('mel_ptmobile_v2.pt')`:

Python log:

inputs 0 134400 [-1.4953613e-03 -1.6479492e-03 -1.4648438e-03 ... 
inputs 1 torch.Size([1, 134400]) tensor([[-1.4954e-03, -1.6479e-03, -1.4648e-03,  . <class 'torch.Tensor'>
inputs 2 torch.Size([1, 128, 420]) tensor([[[-0.7647, -0.5746, -0.6255,  ..., -1.3985
outputs 0 torch.Size([1, 527]) tensor([[ -3.1250,  -6.5625,  -7.0312,  -7.7500, 

Python invocation:

# Load the preprocessing model that maps a raw waveform to a mel spectrogram.
mel = load_model_from_uri(mel_ptmobile_name)

# librosa returns (samples, sample_rate); only the samples are kept.
waveform, _ = librosa.core.load(audio_path, sr=sample_rate, mono=True)
if DEBUG:
    print('inputs 0', len(waveform), str(waveform)[:50])

# Prepend a batch axis -> (1, num_samples), then move the tensor to the target device.
batched = waveform[None, :]
waveform = torch.from_numpy(batched).to(device)
if DEBUG:
    print('inputs 1', waveform.shape, str(waveform)[:50], type(waveform))

# The models were trained in half precision (torch.float16), so on CUDA run under
# autocast for best performance; on CPU plain torch.float32 performs similarly
# (torch.bfloat16 is worse), so no autocast is used there.
precision_ctx = autocast(device_type=device.type) if cuda else nullcontext()
with torch.no_grad(), precision_ctx:
    spec = mel(waveform)
    if DEBUG:
        print('inputs 2', spec.shape, str(spec)[:50])

Python model:

class AugmentMelSTFT(nn.Module):
    """Turn a batch of raw waveforms into normalized log-mel spectrograms.

    ``forward`` pipeline: pre-emphasis -> STFT -> power spectrum -> mel
    filterbank -> log -> (training only) frequency/time masking -> fixed
    affine normalization. In training mode the filterbank's lower and upper
    edge frequencies are also randomly jittered.

    Args:
        n_mels: number of mel bands in the output.
        sr: sample rate passed to the mel filterbank.
        win_length: STFT window length (a non-periodic Hann window).
        hopsize: STFT hop length.
        n_fft: FFT size.
        freqm: mask parameter for torchaudio ``FrequencyMasking`` (training
            only); 0 disables it.
        timem: mask parameter for torchaudio ``TimeMasking`` (training only);
            0 disables it.
        fmin: lower filterbank edge before jitter.
        fmax: upper filterbank edge before jitter; ``None`` means
            ``sr // 2 - fmax_aug_range // 2``.
        fmin_aug_range: in training, fmin is raised by a random integer in
            ``[0, fmin_aug_range)``; must be >= 1 (1 means no augmentation).
        fmax_aug_range: in training, fmax is shifted by
            ``fmax_aug_range // 2 - k`` with k a random integer in
            ``[0, fmax_aug_range)``; must be >= 1 (1 means no augmentation).

    Raises:
        AssertionError: if ``fmin_aug_range`` or ``fmax_aug_range`` is < 1.
    """

    def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48, timem=192,
                 fmin=0.0, fmax=None, fmin_aug_range=10, fmax_aug_range=2000):
        super().__init__()
        # adapted from: https://github.com/CPJKU/kagglebirds2020/commit/70f8308b39011b09d41eb0f4ace5aa7d2b0e806e

        self.win_length = win_length
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.sr = sr
        self.fmin = fmin
        if fmax is None:
            # Leave headroom so the largest upward jitter (+fmax_aug_range // 2)
            # lands exactly at sr // 2.
            fmax = sr // 2 - fmax_aug_range // 2
            if DEBUG: print(f"Warning: FMAX is None setting to {fmax} ")
        self.fmax = fmax
        self.hopsize = hopsize
        # Non-persistent buffers follow .to(device) but are not saved in the state_dict.
        self.register_buffer('window',
                             torch.hann_window(win_length, periodic=False),
                             persistent=False)
        assert fmin_aug_range >= 1, f"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation"
        assert fmax_aug_range >= 1, f"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation"
        self.fmin_aug_range = fmin_aug_range
        self.fmax_aug_range = fmax_aug_range

        # conv1d kernel of shape (out=1, in=1, width=2) used for pre-emphasis.
        self.register_buffer("preemphasis_coefficient", torch.as_tensor([[[-.97, 1]]]), persistent=False)
        if freqm == 0:
            self.freqm = torch.nn.Identity()
        else:
            self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
        if timem == 0:
            self.timem = torch.nn.Identity()
        else:
            self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=True)

    def forward(self, x):
        """Compute the normalized log-mel spectrogram of ``x``.

        Args:
            x: waveform batch of shape ``(batch, num_samples)``; it is
                unsqueezed to ``(batch, 1, num_samples)`` for the pre-emphasis
                conv1d.

        Returns:
            Tensor of shape ``(batch, n_mels, frames)``.
        """
        if onnx_conf.DEBUG: print('mel.forward,', x.shape, x[0][0].dtype, type(x))
        # Pre-emphasis: conv1d with kernel [-0.97, 1] yields x[t + 1] - 0.97 * x[t]
        # (the output is one sample shorter than the input).
        x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient).squeeze(1)
        # return_complex=False gives a trailing (real, imag) axis of size 2.
        x = torch.stft(x, self.n_fft, hop_length=self.hopsize, win_length=self.win_length,
                       center=True, normalized=False, window=self.window, return_complex=False)
        x = (x ** 2).sum(dim=-1)  # power magnitude: real^2 + imag^2

        if self.training:
            # Random jitter of the filterbank edges (same RNG draw order as before).
            fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()
            fmax = self.fmax + self.fmax_aug_range // 2 - torch.randint(self.fmax_aug_range, (1,)).item()
        else:
            # Don't augment eval data -- and don't draw random numbers at all, so
            # inference neither advances the global RNG nor pays for .item() syncs.
            fmin = self.fmin
            fmax = self.fmax

        mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(self.n_mels, self.n_fft, self.sr,
                                                                 fmin, fmax, vtln_low=100.0, vtln_high=-500.,
                                                                 vtln_warp_factor=1.0)
        # Append one zero column so the filterbank width matches the STFT's
        # frequency-bin count (presumably n_fft // 2 -> n_fft // 2 + 1; verify).
        mel_basis = torch.as_tensor(torch.nn.functional.pad(mel_basis, (0, 1), mode='constant', value=0),
                                    device=x.device)
        # Keep the filterbank projection out of CUDA autocast (full precision).
        with torch.cuda.amp.autocast(enabled=False):
            melspec = torch.matmul(mel_basis, x)

        # Small offset avoids log(0).
        melspec = (melspec + 0.00001).log()

        if self.training:
            melspec = self.freqm(melspec)
            melspec = self.timem(melspec)

        melspec = (melspec + 4.5) / 5.  # fast normalization

        return melspec

Solution

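  • Solved by changing the Android invocation to pass the tensor directly with `IValue.from` instead of wrapping it in a list with `IValue.listFrom`. The model's declared signature is `forward(Tensor x)`, so it needs a single `Tensor`, whereas `listFrom` produces a list. The input length was also set to 134400 samples, the same length used in the Python run: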

    val wavLength = 134400
    ...
    val forward = melModel!!.forward(IValue.from(melInputTensor))