
simpletransformers use_cuda=True not working


I wanted to try CUDA (I have an RTX 3070 Ti) on my Windows setup, using this code:

import pandas as pd 
from simpletransformers.classification import ClassificationModel
from sklearn.model_selection import train_test_split 
from sklearn import preprocessing 

# Read data from JSON 
df = pd.read_json(r"C:\Users\NLP and Computer Vision\Coding\News_Category_Dataset_v3.json", orient="records", lines=True)

print(df[df["headline"].isna() | df["short_description"].isna()])
print(df.head())


data = pd.DataFrame() 
data["text"] = df.headline + " " + df.short_description 
data["labels"] = df.category
 
labels = list(data["labels"].unique()) 

# Convert labels to numerical values 
le = preprocessing.LabelEncoder() 
le.fit(labels)
data["labels"] = le.transform(data["labels"])

train_df, eval_df = train_test_split(data, test_size=0.2) 

# Create a classification model
# This runs with use_cuda=False; with use_cuda=True it does not work
model = ClassificationModel('bert', 'bert-base-uncased', num_labels=len(labels), use_cuda=False)

# Train the model 
model.train_model(train_df)
 
# Evaluate the model 
result, model_outputs, predictions = model.eval_model(eval_df) 

It runs when I set use_cuda=False, but when I set it to True it does not work.

I installed the NVIDIA CUDA Toolkit both on Windows and in Anaconda. What am I doing wrong?
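A quick way to narrow this down is to ask PyTorch itself whether it can see the GPU. In the simpletransformers versions I am aware of, `use_cuda=True` makes the model check `torch.cuda.is_available()` and raise a `ValueError` if it returns `False`. A separately installed CUDA Toolkit does not change that result if the installed PyTorch is a CPU-only build, because PyTorch's binary packages bundle their own CUDA runtime and only need the NVIDIA driver. The helper below, `cuda_diagnostics`, is a hypothetical sketch that uses only standard `torch` attributes:

```python
import importlib.util


def cuda_diagnostics() -> dict:
    """Hypothetical helper: report whether the installed PyTorch can use a CUDA GPU."""
    # If PyTorch is missing entirely, nothing can run on the GPU
    if importlib.util.find_spec("torch") is None:
        return {"torch_installed": False, "cuda_available": False}

    import torch

    info = {
        "torch_installed": True,
        "torch_version": torch.__version__,
        # CUDA version PyTorch was built against; None for CPU-only builds
        "built_with_cuda": torch.version.cuda,
        # What simpletransformers relies on when use_cuda=True
        "cuda_available": torch.cuda.is_available(),
    }
    if info["cuda_available"]:
        # Name of the first visible GPU, e.g. the RTX card
        info["device_name"] = torch.cuda.get_device_name(0)
    return info


if __name__ == "__main__":
    for key, value in cuda_diagnostics().items():
        print(f"{key}: {value}")
```

If `built_with_cuda` is `None`, the environment contains a CPU-only PyTorch build and needs a CUDA-enabled one. If it shows a CUDA version but `cuda_available` is still `False`, the NVIDIA driver is the more likely culprit.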


Solution

  • I don't know exactly what was wrong with the original environment, but the environment was the problem.

    I solved it by creating a new conda environment from this YAML file:

    name: nlp
    channels:
      - pytorch
      - intel
      - nvidia
      - conda-forge
      - defaults
    dependencies:
      - alabaster=0.7.12=pyhd3eb1b0_0
      - anyio=3.5.0=py310haa95532_0
      - appdirs=1.4.4=pyhd3eb1b0_0
      - argon2-cffi=21.3.0=pyhd3eb1b0_0
      - argon2-cffi-bindings=21.2.0=py310h2bbff1b_0
      - arrow=1.2.3=py310haa95532_1
      - astroid=2.14.2=py310haa95532_0
      - asttokens=2.0.5=pyhd3eb1b0_0
      - atomicwrites=1.4.0=py_0
      - attrs=22.1.0=py310haa95532_0
      - autopep8=1.6.0=pyhd3eb1b0_1
      - babel=2.11.0=py310haa95532_0
      - backcall=0.2.0=pyhd3eb1b0_0
      - bcrypt=3.2.0=py310h2bbff1b_1
      - beautifulsoup4=4.12.2=py310haa95532_0
      - binaryornot=0.4.4=pyhd3eb1b0_1
      - black=23.3.0=py310haa95532_0
      - blas=1.0=mkl
      - bleach=4.1.0=pyhd3eb1b0_0
      - bottleneck=1.3.5=py310h9128911_0
      - brotlipy=0.7.0=py310h2bbff1b_1002
      - bzip2=1.0.8=he774522_0
      - ca-certificates=2022.12.7=h5b45459_0
      - certifi=2022.12.7=pyhd8ed1ab_0
      - cffi=1.15.1=py310h2bbff1b_3
      - chardet=4.0.0=py310haa95532_1003
      - charset-normalizer=2.0.4=pyhd3eb1b0_0
      - click=8.0.4=py310haa95532_0
      - cloudpickle=2.0.0=pyhd3eb1b0_0
      - colorama=0.4.6=py310haa95532_0
      - comm=0.1.2=py310haa95532_0
      - cookiecutter=1.7.3=pyhd3eb1b0_0
      - cryptography=39.0.1=py310h21b164f_0
      - cuda-cccl=12.1.109=0
      - cuda-cudart=11.7.99=0
      - cuda-cudart-dev=11.7.99=0
      - cuda-cupti=11.7.101=0
      - cuda-libraries=11.7.1=0
      - cuda-libraries-dev=11.7.1=0
      - cuda-nvrtc=11.7.99=0
      - cuda-nvrtc-dev=11.7.99=0
      - cuda-nvtx=11.7.91=0
      - cuda-runtime=11.7.1=0
      - debugpy=1.5.1=py310hd77b12b_0
      - decorator=5.1.1=pyhd3eb1b0_0
      - defusedxml=0.7.1=pyhd3eb1b0_0
      - diff-match-patch=20200713=pyhd3eb1b0_0
      - dill=0.3.6=py310haa95532_0
      - docstring-to-markdown=0.11=py310haa95532_0
      - docutils=0.18.1=py310haa95532_3
      - entrypoints=0.4=py310haa95532_0
      - executing=0.8.3=pyhd3eb1b0_0
      - filelock=3.9.0=py310haa95532_0
      - flake8=6.0.0=py310haa95532_0
      - freetype=2.12.1=ha860e81_0
      - giflib=5.2.1=h8cc25b3_3
      - glib=2.69.1=h5dc1a3c_2
      - gst-plugins-base=1.18.5=h9e645db_0
      - gstreamer=1.18.5=hd78058f_0
      - icc_rt=2022.1.0=h6049295_2
      - icu=58.2=ha925a31_3
      - idna=3.4=py310haa95532_0
      - imagesize=1.4.1=py310haa95532_0
      - importlib-metadata=6.0.0=py310haa95532_0
      - importlib_metadata=6.0.0=hd3eb1b0_0
      - inflection=0.5.1=py310haa95532_0
      - intel-openmp=2021.4.0=haa95532_3556
      - intervaltree=3.1.0=pyhd3eb1b0_0
      - ipykernel=6.19.2=py310h9909e9c_0
      - ipython=8.12.0=py310haa95532_0
      - ipython_genutils=0.2.0=pyhd3eb1b0_1
      - ipywidgets=8.0.6=pyhd8ed1ab_0
      - isort=5.9.3=pyhd3eb1b0_0
      - jaraco.classes=3.2.1=pyhd3eb1b0_0
      - jedi=0.18.1=py310haa95532_1
      - jellyfish=0.9.0=py310h2bbff1b_0
      - jinja2=3.1.2=py310haa95532_0
      - jinja2-time=0.2.0=pyhd3eb1b0_3
      - joblib=1.2.0=pyh3f38642_0
      - jpeg=9e=h2bbff1b_1
      - jsonschema=4.17.3=py310haa95532_0
      - jupyter_client=8.1.0=py310haa95532_0
      - jupyter_core=5.3.0=py310haa95532_0
      - jupyter_server=1.23.4=py310haa95532_0
      - jupyterlab_pygments=0.1.2=py_0
      - jupyterlab_widgets=3.0.7=pyhd8ed1ab_0
      - keyring=23.13.1=py310haa95532_0
      - krb5=1.19.4=h5b6d351_0
      - lazy-object-proxy=1.6.0=py310h2bbff1b_0
      - lerc=3.0=hd77b12b_0
      - libclang=14.0.6=default_hb5a9fac_1
      - libclang13=14.0.6=default_h8e68704_1
      - libcublas=11.10.3.66=0
      - libcublas-dev=11.10.3.66=0
      - libcufft=10.7.2.124=0
      - libcufft-dev=10.7.2.124=0
      - libcurand=10.3.2.106=0
      - libcurand-dev=10.3.2.106=0
      - libcusolver=11.4.0.1=0
      - libcusolver-dev=11.4.0.1=0
      - libcusparse=11.7.4.91=0
      - libcusparse-dev=11.7.4.91=0
      - libdeflate=1.17=h2bbff1b_0
      - libffi=3.4.2=hd77b12b_6
      - libiconv=1.16=h2bbff1b_2
      - libnpp=11.7.4.75=0
      - libnpp-dev=11.7.4.75=0
      - libnvjpeg=11.8.0.2=0
      - libnvjpeg-dev=11.8.0.2=0
      - libogg=1.3.5=h2bbff1b_1
      - libpng=1.6.39=h8cc25b3_0
      - libsodium=1.0.18=h62dcd97_0
      - libspatialindex=1.9.3=h6c2663c_0
      - libtiff=4.5.0=h6c2663c_2
      - libuv=1.44.2=h2bbff1b_0
      - libvorbis=1.3.7=he774522_0
      - libwebp=1.2.4=hbc33d0d_1
      - libwebp-base=1.2.4=h2bbff1b_1
      - libxml2=2.10.3=h0ad7f3c_0
      - libxslt=1.1.37=h2bbff1b_0
      - lxml=4.9.2=py310h2bbff1b_0
      - lz4-c=1.9.4=h2bbff1b_0
      - markupsafe=2.1.1=py310h2bbff1b_0
      - matplotlib-inline=0.1.6=py310haa95532_0
      - mccabe=0.7.0=pyhd3eb1b0_0
      - mistune=0.8.4=py310h2bbff1b_1000
      - mkl=2021.4.0=haa95532_640
      - mkl-service=2.4.0=py310h2bbff1b_0
      - mkl_fft=1.3.1=py310ha0764ea_0
      - mkl_random=1.2.2=py310h4ed8f06_0
      - more-itertools=8.12.0=pyhd3eb1b0_0
      - mpmath=1.2.1=py310haa95532_0
      - mypy_extensions=0.4.3=py310haa95532_1
      - nbclassic=0.5.5=py310haa95532_0
      - nbclient=0.5.13=py310haa95532_0
      - nbconvert=6.5.4=py310haa95532_0
      - nbformat=5.7.0=py310haa95532_0
      - nest-asyncio=1.5.6=py310haa95532_0
      - networkx=2.8.4=py310haa95532_1
      - notebook=6.5.4=py310haa95532_0
      - notebook-shim=0.2.2=py310haa95532_0
      - numexpr=2.8.4=py310hd213c9f_0
      - numpy=1.24.3=py310hdc03b94_0
      - numpy-base=1.24.3=py310h3caf3d7_0
      - numpydoc=1.5.0=py310haa95532_0
      - openssl=1.1.1t=h2bbff1b_0
      - packaging=23.0=py310haa95532_0
      - pandas=1.5.3=py310h4ed8f06_0
      - pandocfilters=1.5.0=pyhd3eb1b0_0
      - paramiko=2.8.1=pyhd3eb1b0_0
      - parso=0.8.3=pyhd3eb1b0_0
      - pathspec=0.10.3=py310haa95532_0
      - pcre=8.45=hd77b12b_0
      - pexpect=4.8.0=pyhd3eb1b0_3
      - pickleshare=0.7.5=pyhd3eb1b0_1003
      - pillow=9.4.0=py310hd77b12b_0
      - pip=23.0.1=py310haa95532_0
      - platformdirs=2.5.2=py310haa95532_0
      - pluggy=1.0.0=py310haa95532_1
      - ply=3.11=py310haa95532_0
      - pooch=1.4.0=pyhd3eb1b0_0
      - poyo=0.5.0=pyhd3eb1b0_0
      - prometheus_client=0.14.1=py310haa95532_0
      - prompt-toolkit=3.0.36=py310haa95532_0
      - psutil=5.9.0=py310h2bbff1b_0
      - ptyprocess=0.7.0=pyhd3eb1b0_2
      - pure_eval=0.2.2=pyhd3eb1b0_0
      - pycodestyle=2.10.0=py310haa95532_0
      - pycparser=2.21=pyhd3eb1b0_0
      - pydocstyle=6.3.0=py310haa95532_0
      - pyflakes=3.0.1=py310haa95532_0
      - pylint=2.16.2=py310haa95532_0
      - pyls-spyder=0.4.0=pyhd3eb1b0_0
      - pynacl=1.5.0=py310h8cc25b3_0
      - pyopenssl=23.0.0=py310haa95532_0
      - pyqt=5.15.7=py310hd77b12b_0
      - pyqt5-sip=12.11.0=py310hd77b12b_0
      - pyqtwebengine=5.15.7=py310hd77b12b_0
      - pyrsistent=0.18.0=py310h2bbff1b_0
      - pysocks=1.7.1=py310haa95532_0
      - python=3.10.11=h966fe2a_2
      - python-dateutil=2.8.2=pyhd3eb1b0_0
      - python-fastjsonschema=2.16.2=py310haa95532_0
      - python-lsp-black=1.2.1=py310haa95532_0
      - python-lsp-jsonrpc=1.0.0=pyhd3eb1b0_0
      - python-lsp-server=1.7.2=py310haa95532_0
      - python-slugify=5.0.2=pyhd3eb1b0_0
      - pytoolconfig=1.2.5=py310haa95532_1
      - pytorch=2.0.0=py3.10_cuda11.7_cudnn8_0
      - pytorch-cuda=11.7=h16d0643_3
      - pytorch-mutex=1.0=cuda
      - pytz=2022.7=py310haa95532_0
      - pywin32=305=py310h2bbff1b_0
      - pywin32-ctypes=0.2.0=py310haa95532_1000
      - pywinpty=2.0.10=py310h5da7b33_0
      - pyyaml=6.0=py310h2bbff1b_1
      - pyzmq=25.0.2=py310hd77b12b_0
      - qdarkstyle=3.0.2=pyhd3eb1b0_0
      - qt-main=5.15.2=he8e5bd7_8
      - qt-webengine=5.15.9=hb9a9bb5_5
      - qtconsole=5.4.2=py310haa95532_0
      - qtpy=2.2.0=py310haa95532_0
      - qtwebkit=5.212=h2bbfb41_5
      - requests=2.29.0=py310haa95532_0
      - rope=1.7.0=py310haa95532_0
      - rtree=1.0.1=py310h2eaa2aa_0
      - scikit-learn=1.2.1=py310hd77b12b_0
      - scipy=1.10.1=py310hb9afe5d_0
      - send2trash=1.8.0=pyhd3eb1b0_1
      - setuptools=66.0.0=py310haa95532_0
      - sip=6.6.2=py310hd77b12b_0
      - six=1.16.0=pyhd3eb1b0_1
      - sniffio=1.2.0=py310haa95532_1
      - snowballstemmer=2.2.0=pyhd3eb1b0_0
      - sortedcontainers=2.4.0=pyhd3eb1b0_0
      - soupsieve=2.4=py310haa95532_0
      - sphinx=5.0.2=py310haa95532_0
      - sphinxcontrib-applehelp=1.0.2=pyhd3eb1b0_0
      - sphinxcontrib-devhelp=1.0.2=pyhd3eb1b0_0
      - sphinxcontrib-htmlhelp=2.0.0=pyhd3eb1b0_0
      - sphinxcontrib-jsmath=1.0.1=pyhd3eb1b0_0
      - sphinxcontrib-qthelp=1.0.3=pyhd3eb1b0_0
      - sphinxcontrib-serializinghtml=1.1.5=pyhd3eb1b0_0
      - spyder=5.4.3=py310haa95532_1
      - spyder-kernels=2.4.3=py310haa95532_0
      - sqlite=3.41.2=h2bbff1b_0
      - stack_data=0.2.0=pyhd3eb1b0_0
      - sympy=1.11.1=py310haa95532_0
      - terminado=0.17.1=py310haa95532_0
      - text-unidecode=1.3=pyhd3eb1b0_0
      - textdistance=4.2.1=pyhd3eb1b0_0
      - threadpoolctl=2.2.0=pyh0d69192_0
      - three-merge=0.1.1=pyhd3eb1b0_0
      - tinycss2=1.2.1=py310haa95532_0
      - tk=8.6.12=h2bbff1b_0
      - toml=0.10.2=pyhd3eb1b0_0
      - tomli=2.0.1=py310haa95532_0
      - tomlkit=0.11.1=py310haa95532_0
      - tornado=6.2=py310h2bbff1b_0
      - traitlets=5.7.1=py310haa95532_0
      - typing-extensions=4.5.0=py310haa95532_0
      - typing_extensions=4.5.0=py310haa95532_0
      - ujson=5.4.0=py310hd77b12b_0
      - unidecode=1.2.0=pyhd3eb1b0_0
      - urllib3=1.26.15=py310haa95532_0
      - vc=14.2=h21ff451_1
      - vs2015_runtime=14.27.29016=h5e58377_2
      - watchdog=2.1.6=py310haa95532_0
      - wcwidth=0.2.5=pyhd3eb1b0_0
      - webencodings=0.5.1=py310haa95532_1
      - websocket-client=0.58.0=py310haa95532_4
      - whatthepatch=1.0.2=py310haa95532_0
      - wheel=0.38.4=py310haa95532_0
      - widgetsnbextension=4.0.7=pyhd8ed1ab_0
      - win_inet_pton=1.1.0=py310haa95532_0
      - winpty=0.4.3=4
      - wrapt=1.14.1=py310h2bbff1b_0
      - xz=5.2.10=h8cc25b3_1
      - yaml=0.2.5=he774522_0
      - yapf=0.31.0=pyhd3eb1b0_0
      - zeromq=4.3.4=hd77b12b_0
      - zipp=3.11.0=py310haa95532_0
      - zlib=1.2.13=h8cc25b3_0
      - zstd=1.5.5=hd43e919_0
      - pip:
          - absl-py==1.4.0
          - aiohttp==3.8.4
          - aiosignal==1.3.1
          - altair==4.2.2
          - async-timeout==4.0.2
          - blinker==1.6.2
          - cachetools==5.3.0
          - datasets==2.12.0
          - docker-pycreds==0.4.0
          - frozenlist==1.3.3
          - fsspec==2023.4.0
          - gitdb==4.0.10
          - gitpython==3.1.31
          - google-auth==2.17.3
          - google-auth-oauthlib==1.0.0
          - grpcio==1.54.0
          - huggingface-hub==0.14.1
          - markdown==3.4.3
          - markdown-it-py==2.2.0
          - mdurl==0.1.2
          - multidict==6.0.4
          - multiprocess==0.70.14
          - oauthlib==3.2.2
          - pathtools==0.1.2
          - protobuf==3.20.3
          - pyarrow==11.0.0
          - pyasn1==0.5.0
          - pyasn1-modules==0.3.0
          - pydeck==0.8.1b0
          - pygments==2.15.1
          - pylint-venv==2.3.0
          - pympler==1.0.1
          - pytz-deprecation-shim==0.1.0.post0
          - qstylizer==0.2.2
          - qtawesome==1.2.2
          - regex==2023.3.23
          - requests-oauthlib==1.3.1
          - responses==0.18.0
          - rich==13.3.5
          - rsa==4.9
          - sentencepiece==0.1.98
          - sentry-sdk==1.21.1
          - seqeval==1.2.2
          - setproctitle==1.3.2
          - simpletransformers==0.63.11
          - smmap==5.0.0
          - streamlit==1.22.0
          - tenacity==8.2.2
          - tensorboard==2.12.2
          - tensorboard-data-server==0.7.0
          - tensorboard-plugin-wit==1.8.1
          - tokenizers==0.13.3
          - toolz==0.12.0
          - torchaudio==2.0.0
          - torchvision==0.15.0
          - tqdm==4.65.0
          - transformers==4.28.1
          - tzdata==2023.3
          - tzlocal==4.3
          - validators==0.20.0
          - wandb==0.15.0
          - werkzeug==2.3.2
          - xxhash==3.2.0
          - yarl==1.9.2
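For reference, the CUDA-specific entries in this file appear to be `pytorch=2.0.0=py3.10_cuda11.7_cudnn8_0`, `pytorch-cuda=11.7` and `pytorch-mutex=1.0=cuda`, which select a CUDA 11.7 build of PyTorch rather than a CPU-only one. Once the environment is created (for example with `conda env create -f environment.yml`, where the file name is an assumption) and activated, a small guard can make the training script explain why the GPU is not being used instead of failing with little context. `resolve_use_cuda` below is a hypothetical helper, not part of simpletransformers:

```python
import importlib.util
import warnings


def resolve_use_cuda(requested: bool = True) -> bool:
    """Hypothetical helper: return True only if CUDA is requested and usable.

    Intended to be passed as ClassificationModel(..., use_cuda=...).
    """
    if not requested:
        return False
    # Without PyTorch there is no GPU support at all
    if importlib.util.find_spec("torch") is None:
        warnings.warn("PyTorch is not installed; falling back to CPU.")
        return False

    import torch

    if not torch.cuda.is_available():
        # torch.version.cuda is None for CPU-only PyTorch builds
        build = torch.version.cuda or "CPU-only"
        warnings.warn(
            f"CUDA requested but not available (PyTorch CUDA build: {build}); "
            "falling back to CPU."
        )
        return False
    return True


if __name__ == "__main__":
    print("use_cuda =", resolve_use_cuda(True))
```

The result would then be passed as `use_cuda=resolve_use_cuda(True)` when constructing `ClassificationModel`. Falling back to the CPU silently would hide exactly the problem described in the question, which is why the sketch emits a warning that includes the PyTorch build instead of quietly continuing.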