Search code examples
rustdeep-learningpytorchtorchlibtorch

The shape of the Tensor output from the "forward_t" method in "tch-rs" does not match the size of the input tensor in "cross_entropy_for_logits"


let loss = net.forward_t(&batch_images, true).cross_entropy_for_logits(&batch_lbls);

The batch_images here refers to 128 grayscale images of size 512x512, which are converted into tensors and reshaped into two-dimensional tensors with a shape of [128, 262144]. The shape of batch_lbls is [128]. Then here is the code for the neural network:

use tch::{Tensor, nn, nn::ModuleT};

#[derive(Debug)]
pub struct ConvNN {
    conv1: nn::Conv2D,
    conv2: nn::Conv2D,
    fc1: nn::Linear,
    fc2: nn::Linear,
}

impl ConvNN {
    pub fn new(vs: &nn::Path) -> ConvNN {
        let conv1 = nn::conv2d(vs, 1, 32, 5, Default::default());
        let conv2 = nn::conv2d(vs, 32, 64, 5, Default::default());
        let fc1 = nn::linear(vs, 1024, 1024, Default::default());
        let fc2 = nn::linear(vs, 1024, 5, Default::default());
        ConvNN { conv1, conv2, fc1, fc2 }
    }
}

impl ModuleT for ConvNN {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
        xs.view([-1, 1, 512, 512])
            .apply(&self.conv1)
            .max_pool2d_default(2)
            .apply(&self.conv2)
            .max_pool2d_default(2)
            .view([-1, 1024])
            .apply(&self.fc1)
            .relu()
            .dropout(0.5, train)
            .apply(&self.fc2)
    }
}

Error occurred while running the program:

thread 'main' panicked at 'called `Result::unwrap()` on an `Err` value: Torch("size mismatch (got input: [125000, 5], target: [128])\nException raised from meta at /build/python-pytorch/src/pytorch/aten/src/ATen/native/LossNLL.cpp:52 (most recent call first):\nframe #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xa2 (0x7fa9bd5b8852 in /usr/lib/libc10.so)\nframe #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xe1 (0x7fa9bd57e86f in /usr/lib/libc10.so)\nframe #2: at::meta::structured_nll_loss_forward::meta(at::Tensor const&, at::Tensor const&, at::OptionalTensorRef, long, long) + 0x43a (0x7fa9bedc890a in /usr/lib/libtorch_cpu.so)\nframe #3: <unknown function> + 0x257cc83 (0x7fa9bfb7cc83 in /usr/lib/libtorch_cpu.so)\nframe #4: <unknown function> + 0x257cda4 (0x7fa9bfb7cda4 in /usr/lib/libtorch_cpu.so)\nframe #5: at::_ops::nll_loss_forward::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, long, c10::SymInt) + 0x206 (0x7fa9bf7a50e6 in /usr/lib/libtorch_cpu.so)\nframe #6: <unknown function> + 0x48759ff (0x7fa9c1e759ff in /usr/lib/libtorch_cpu.so)\nframe #7: <unknown function> + 0x48761cc (0x7fa9c1e761cc in /usr/lib/libtorch_cpu.so)\nframe #8: at::_ops::nll_loss_forward::call(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, long, c10::SymInt) + 0x219 (0x7fa9bf840b19 in /usr/lib/libtorch_cpu.so)\nframe #9: at::native::nll_loss_symint(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, long, c10::SymInt) + 0xb7 (0x7fa9bedc9367 in /usr/lib/libtorch_cpu.so)\nframe #10: <unknown function> + 0x28ca4a6 (0x7fa9bfeca4a6 in /usr/lib/libtorch_cpu.so)\nframe #11: at::_ops::nll_loss::call(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, long, c10::SymInt) + 0x219 (0x7fa9bf9a8179 in /usr/lib/libtorch_cpu.so)\nframe #12: <unknown function> + 0x126040 (0x55a956179040 in ./target/release/dl-glass)\nframe #13: <unknown function> + 0x110ef4 (0x55a956163ef4 in ./target/release/dl-glass)\nframe #14: <unknown function> + 0x111045 (0x55a956164045 in ./target/release/dl-glass)\nframe #15: <unknown function> + 0x3c88d (0x55a95608f88d in ./target/release/dl-glass)\nframe #16: <unknown function> + 0x2ed93 (0x55a956081d93 in ./target/release/dl-glass)\nframe #17: <unknown function> + 0x837a9 (0x55a9560d67a9 in ./target/release/dl-glass)\nframe #18: <unknown function> + 0x2511dc (0x55a9562a41dc in ./target/release/dl-glass)\nframe #19: <unknown function> + 0x3d1e5 (0x55a9560901e5 in ./target/release/dl-glass)\nframe #20: <unknown function> + 0x23850 (0x7fa9bd2b2850 in /usr/lib/libc.so.6)\nframe #21: __libc_start_main + 0x8a (0x7fa9bd2b290a in /usr/lib/libc.so.6)\nframe #22: <unknown function> + 0x2c5c5 (0x55a95607f5c5 in ./target/release/dl-glass)\n")', /home/aliez/.cargo/registry/src/github.com-1ecc6299db9ec823/tch-0.13.0/src/wrappers/tensor_generated.rs:13275:66
stack backtrace:
   0:     0x55a9562aa16a - std::backtrace_rs::backtrace::libunwind::trace::ha271a8a7e1f3d4ef
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/../../backtrace/src/backtrace/libunwind.rs:93:5
   1:     0x55a9562aa16a - std::backtrace_rs::backtrace::trace_unsynchronized::h85739da0352c791a
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/../../backtrace/src/backtrace/mod.rs:66:5
   2:     0x55a9562aa16a - std::sys_common::backtrace::_print_fmt::hbc6ebcfb2910b329
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/sys_common/backtrace.rs:65:5
   3:     0x55a9562aa16a - <std::sys_common::backtrace::_print::DisplayBacktrace as core::fmt::Display>::fmt::he1c117e52d53614f
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/sys_common/backtrace.rs:44:22
   4:     0x55a9562cb36e - core::fmt::write::h25eb51b9526b8e0c
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/core/src/fmt/mod.rs:1213:17
   5:     0x55a9562a7c15 - std::io::Write::write_fmt::ha9edec5fb1621933
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/io/mod.rs:1682:15
   6:     0x55a9562a9f35 - std::sys_common::backtrace::_print::hf8657cd429fc3452
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/sys_common/backtrace.rs:47:5
   7:     0x55a9562a9f35 - std::sys_common::backtrace::print::h41b9b18ed86f86bd
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/sys_common/backtrace.rs:34:9
   8:     0x55a9562ab71f - std::panicking::default_hook::{{closure}}::h22a91871f4454152
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/panicking.rs:267:22
   9:     0x55a9562ab45b - std::panicking::default_hook::h21ddc36de0cd4ae7
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/panicking.rs:286:9
  10:     0x55a9562abe29 - std::panicking::rust_panic_with_hook::h5059419d6d59b3d0
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/panicking.rs:688:13
  11:     0x55a9562abbc9 - std::panicking::begin_panic_handler::{{closure}}::h0f383c291cd78343
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/panicking.rs:579:13
  12:     0x55a9562aa61c - std::sys_common::backtrace::__rust_end_short_backtrace::h70ab22f2ad318cdd
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/sys_common/backtrace.rs:137:18
  13:     0x55a9562ab8d2 - rust_begin_unwind
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/panicking.rs:575:5
  14:     0x55a95607ef83 - core::panicking::panic_fmt::hd1d46bcde3c61d72
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/core/src/panicking.rs:64:14
  15:     0x55a95607f433 - core::result::unwrap_failed::h456a23f68607268c
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/core/src/result.rs:1790:5
  16:     0x55a956163fc5 - tch::tensor::<impl tch::wrappers::tensor::Tensor>::nll_loss::h58fd748ab11dd56e
  17:     0x55a956164045 - tch::tensor::<impl tch::wrappers::tensor::Tensor>::cross_entropy_for_logits::h24a191fa35112726
  18:     0x55a95608f88d - dl_glass::main::h042aa9d4dc08e08f
  19:     0x55a956081d93 - std::sys_common::backtrace::__rust_begin_short_backtrace::h7950d2f97058a47a
  20:     0x55a9560d67a9 - std::rt::lang_start::{{closure}}::hdec7311f5a04322c
  21:     0x55a9562a41dc - core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once::h203afb3af230319a
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/core/src/ops/function.rs:287:13
  22:     0x55a9562a41dc - std::panicking::try::do_call::hf68e87013b70f3c5
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/panicking.rs:483:40
  23:     0x55a9562a41dc - std::panicking::try::h040ea8f298390ba2
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/panicking.rs:447:19
  24:     0x55a9562a41dc - std::panic::catch_unwind::h1e17b198887a05fa
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/panic.rs:140:14
  25:     0x55a9562a41dc - std::rt::lang_start_internal::{{closure}}::hfb902d8927e51b86
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/rt.rs:148:48
  26:     0x55a9562a41dc - std::panicking::try::do_call::h354e6eb41f2e7d42
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/panicking.rs:483:40
  27:     0x55a9562a41dc - std::panicking::try::h4a39749cd018228c
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/panicking.rs:447:19
  28:     0x55a9562a41dc - std::panic::catch_unwind::h30bce83b8de61cca
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/panic.rs:140:14
  29:     0x55a9562a41dc - std::rt::lang_start_internal::h8f7e70b1a2558118
                               at /rustc/9eb3afe9ebe9c7d2b84b71002d44f4a0edac95e0/library/std/src/rt.rs:148:20
  30:     0x55a9560901e5 - main
  31:     0x7fa9bd2b2850 - <unknown>
  32:     0x7fa9bd2b290a - __libc_start_main
  33:     0x55a95607f5c5 - _start
  34:                0x0 - <unknown>

I have double-checked the shapes of batch_images and batch_lbls, which are [128, 262144] and [128], respectively.


Solution

  • Following what happens in your forward function, the issue comes from the .view([-1, 1024]) line. Since your input after the first view in forward_t is of shape [B, 1, 512, 512], the convolutions and pooling layers are taking it to the dimension [B, 64, 125, 125], where B denotes the batch dimension (128 in your case). But then you are viewing it with target shape [-1, 1024]. -1 in view means that this value will be calculated depending on the input tensor, in this case it will be B*64*125*125/1024, which is in your case 125000, and you are effectively losing your batch dimension.

    After that, linear layers are just considering your input as a 125000-sized batch of 1024-sized vectors and are able to process this data, giving you tensors of shape [125000, 5], which are incompatible to compute cross-entropy with a reference of shape 128.

    To fix this issue, you must keep your batch dimension across all the network, by changing this 1024 value in the second view to 1 000 000 and in the input of the following linear layer for example. If you want to keep 1024 here, you can adapt the settings of the previous convolution and pooling layers to effectively have B*1024 features before the view.