
How to compare a torch::Tensor shape against some other shapes?


I'm trying to compare a torch::Tensor's sizes against something else, but it seems I'm doing it wrong.

I tried:

    auto t = torch::ones({ 3,3 }).sizes();
    std::cout << c10::IntArrayRef{ 3,3 } << std::endl;
    std::cout << (t.equals(c10::IntArrayRef{ 3,3 })) << std::endl;

which always returns false.
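A likely cause of the false result: torch::ones({ 3,3 }) creates a temporary tensor, and sizes() returns a c10::IntArrayRef, which is only a non-owning view of that tensor's size storage. The temporary is destroyed at the end of the statement, so t is left dangling and comparing it is undefined behavior. Keeping the tensor in a named variable avoids this. The sketch below assumes that fix and uses a hypothetical helper, shape_equals (not part of libtorch), which compares any begin()/end() range of int64_t, such as the IntArrayRef from sizes() or a std::vector<int64_t>, against a braced list of expected dimensions.

```cpp
#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <vector>

// Hypothetical helper (not a libtorch API): true when `sizes` holds exactly the
// dimensions in `expected`. `Sizes` can be any range of int64_t with
// begin()/end(), e.g. c10::IntArrayRef (returned by Tensor::sizes()) or
// std::vector<int64_t>.
template <typename Sizes>
bool shape_equals(const Sizes& sizes, std::initializer_list<int64_t> expected) {
    // The four-iterator overload of std::equal (C++14) also returns false
    // when the number of dimensions differs.
    return std::equal(sizes.begin(), sizes.end(),
                      expected.begin(), expected.end());
}

// Usage with libtorch (not compiled here):
//   torch::Tensor t = torch::ones({3, 3});     // named, so it outlives the view
//   bool ok = shape_equals(t.sizes(), {3, 3}); // true
```

Because the helper only iterates, it works directly on the view returned by sizes() as long as the tensor itself is still alive.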

I also tried:

    t == c10::IntArrayRef{ 3,3 };

which also returns false. Since IntArrayRef doesn't own its storage, I tried:

    c10::IntArrayRef x(std::vector<int>{ 3, 3 });

but it fails, saying:

    Error   C2664   'c10::ArrayRef<int64_t>::ArrayRef(c10::ArrayRef<int64_t> &&)': cannot convert argument 1 from 'std::vector<T,std::allocator<int>>' to 'const T &'
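The message follows from the types: c10::IntArrayRef is c10::ArrayRef<int64_t>, so it can view a std::vector<int64_t> but not a std::vector<int>, because a view cannot convert its elements. It also only borrows memory, so the vector must outlive the view; a temporary vector passed straight to the constructor would dangle the same way the temporary tensor does. A minimal sketch, assuming your dimensions start out as int: the hypothetical helper to_int64_dims makes an owning int64_t copy that an IntArrayRef can then safely view.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper (not a libtorch API): widen int dimensions into an
// owning std::vector<int64_t>, the element type c10::IntArrayRef expects.
std::vector<int64_t> to_int64_dims(const std::vector<int>& dims) {
    return std::vector<int64_t>(dims.begin(), dims.end());
}

// Usage with libtorch (not compiled here):
//   std::vector<int64_t> owned = to_int64_dims({3, 3}); // named, so it outlives x
//   c10::IntArrayRef x(owned);                          // non-owning view of `owned`
//   bool same = (t.sizes() == x);
```

If the dimensions are already int64_t, skip the helper and construct the IntArrayRef from a named std::vector<int64_t> directly.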

I'm currently comparing each dimension individually, which is cumbersome and far from ideal. What's wrong here?

Update:

This error shouldn't happen; if it does for you, see issue 43611 in the PyTorch GitHub repository.


Solution

    #include <ATen/ATen.h> // minimal tensor header for faster compilation

    #include <cstdint>
    #include <iostream>
    #include <vector>

    using namespace std;

    int main() {
        // IntArrayRef is a non-owning view, so keep the sizes in owning vectors
        // (an IntArrayRef initialized from a braced list would dangle).
        vector<int64_t> size1 = { 4, 5 };       // 4 rows x 5 cols
        vector<int64_t> size2 = { 4, 6 };       // 4 rows x 6 cols

        at::Tensor t1 = at::zeros(size1);       // allocate a tensor of zeros

        // compare tensor sizes (both sides are IntArrayRef views)
        cout << boolalpha
             << (t1.sizes() == at::IntArrayRef(size1)) << endl   // true
             << (t1.sizes() == at::IntArrayRef(size2)) << endl;  // false

        cin.get();
        return 0;
    }
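Since the question asks about comparing against several shapes, the same equality check can be extended to a list of candidates. A minimal sketch, assuming you only need a yes/no answer: the hypothetical helper matches_any (not part of libtorch) accepts any begin()/end() range of int64_t, so t1.sizes() can be passed directly, and returns true if it equals at least one candidate shape.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical helper (not a libtorch API): true if `sizes` equals at least
// one candidate shape. `Sizes` may be c10::IntArrayRef (from Tensor::sizes())
// or std::vector<int64_t>.
template <typename Sizes>
bool matches_any(const Sizes& sizes,
                 const std::vector<std::vector<int64_t>>& candidates) {
    return std::any_of(candidates.begin(), candidates.end(),
                       [&](const std::vector<int64_t>& shape) {
                           // Four-iterator std::equal also rejects rank mismatches.
                           return std::equal(sizes.begin(), sizes.end(),
                                             shape.begin(), shape.end());
                       });
}

// Usage with libtorch (not compiled here):
//   at::Tensor t1 = at::zeros({4, 5});
//   bool ok = matches_any(t1.sizes(), {{4, 5}, {4, 6}});  // true
```

An empty candidate list always yields false, which is std::any_of's behavior on an empty range.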