I am trying to use a Databricks notebook to fine-tune the Llama 2 model. The code for this is here. I'm running into an error at lines 219-231:
from trl import SFTTrainer
max_seq_length = 512
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
peft_config=peft_config,
dataset_text_field="text",
max_seq_length=max_seq_length,
tokenizer=tokenizer,
args=training_arguments,
)
I am getting the following error:
ImportError: cannot import name 'override' from 'typing_extensions' (/databricks/python/lib/python3.10/site-packages/typing_extensions.py)
The full stack trace is below.
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
File <command-3349581073491723>, line 1
----> 1 from trl import SFTTrainer
3 max_seq_length = 512
5 trainer = SFTTrainer(
6 model=model,
7 train_dataset=dataset,
(...)
12 args=training_arguments,
13 )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-8a237d14-25f4-4066-b47f-e8a95f2342d9/lib/python3.10/site-packages/trl/__init__.py:15
8 from .import_utils import is_diffusers_available, is_peft_available, is_wandb_available, is_xpu_available
9 from .models import (
10 AutoModelForCausalLMWithValueHead,
11 AutoModelForSeq2SeqLMWithValueHead,
12 PreTrainedModelWrapper,
13 create_reference_model,
14 )
---> 15 from .trainer import (
16 DataCollatorForCompletionOnlyLM,
17 DPOTrainer,
18 IterativeSFTTrainer,
19 PPOConfig,
20 PPOTrainer,
21 RewardConfig,
22 RewardTrainer,
23 SFTTrainer,
24 )
27 if is_diffusers_available():
28 from .models import (
29 DDPOPipelineOutput,
30 DDPOSchedulerOutput,
31 DDPOStableDiffusionPipeline,
32 DefaultDDPOStableDiffusionPipeline,
33 )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-8a237d14-25f4-4066-b47f-e8a95f2342d9/lib/python3.10/site-packages/trl/trainer/__init__.py:40
38 from .dpo_trainer import DPOTrainer
39 from .iterative_sft_trainer import IterativeSFTTrainer
---> 40 from .ppo_config import PPOConfig
41 from .ppo_trainer import PPOTrainer
42 from .reward_trainer import RewardTrainer, compute_accuracy
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-8a237d14-25f4-4066-b47f-e8a95f2342d9/lib/python3.10/site-packages/trl/trainer/ppo_config.py:22
19 from typing import Literal, Optional
21 import numpy as np
---> 22 import tyro
23 from typing_extensions import Annotated
25 from trl.trainer.utils import exact_div
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-8a237d14-25f4-4066-b47f-e8a95f2342d9/lib/python3.10/site-packages/tyro/__init__.py:4
1 from typing import TYPE_CHECKING
3 from . import conf as conf
----> 4 from . import extras as extras
5 from ._cli import cli as cli
6 from ._fields import MISSING as MISSING
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-8a237d14-25f4-4066-b47f-e8a95f2342d9/lib/python3.10/site-packages/tyro/extras/__init__.py:5
1 """The :mod:`tyro.extras` submodule contains helpers that complement :func:`tyro.cli()`.
2
3 Compared to the core interface, APIs here are more likely to be changed or deprecated. """
----> 5 from .._argparse_formatter import set_accent_color as set_accent_color
6 from .._cli import get_parser as get_parser
7 from ._base_configs import (
8 subcommand_type_from_defaults as subcommand_type_from_defaults,
9 )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-8a237d14-25f4-4066-b47f-e8a95f2342d9/lib/python3.10/site-packages/tyro/_argparse_formatter.py:37
35 from rich.text import Text
36 from rich.theme import Theme
---> 37 from typing_extensions import override
39 from . import _arguments, _strings, conf
40 from ._parsers import ParserSpecification
ImportError: cannot import name 'override' from 'typing_extensions' (/databricks/python/lib/python3.10/site-packages/typing_extensions.py)
I have tried installing multiple versions of typing_extensions: both the latest version (4.8.0) and 4.7.1, as was suggested in this Stack Overflow post. I also tried the solution posted here, as well as installing the dependencies using the '%' prefix instead of '!', as suggested here. None of this has worked.
Here's a full list of my installed packages:
Package Version
---------------------------- -------------
absl-py 1.0.0
accelerate 0.25.0.dev0
aiohttp 3.8.5
aiosignal 1.3.1
anyio 3.5.0
appdirs 1.4.4
argon2-cffi 21.3.0
argon2-cffi-bindings 21.2.0
astor 0.8.1
asttokens 2.2.1
astunparse 1.6.3
async-timeout 4.0.3
attrs 21.4.0
audioread 3.0.0
azure-core 1.29.1
azure-cosmos 4.3.1
azure-storage-blob 12.17.0
azure-storage-file-datalake 12.12.0
backcall 0.2.0
bcrypt 3.2.0
beautifulsoup4 4.11.1
bitsandbytes 0.41.2.post2
black 22.6.0
bleach 4.1.0
blinker 1.4
blis 0.7.10
boto3 1.24.28
botocore 1.27.28
cachetools 4.2.4
catalogue 2.0.9
category-encoders 2.6.1
certifi 2022.9.14
cffi 1.15.1
chardet 4.0.0
charset-normalizer 2.0.4
click 8.0.4
cloudpickle 2.0.0
cmdstanpy 1.1.0
confection 0.1.1
configparser 5.2.0
convertdate 2.4.0
cryptography 37.0.1
cycler 0.11.0
cymem 2.0.7
Cython 0.29.32
dacite 1.8.1
databricks-automl-runtime 0.2.17
databricks-cli 0.17.7
databricks-feature-store 0.14.1
databricks-sdk 0.1.6
dataclasses-json 0.5.14
datasets 2.13.1
dbl-tempo 0.1.23
dbus-python 1.2.18
debugpy 1.6.0
decorator 5.1.1
defusedxml 0.7.1
dill 0.3.4
diskcache 5.6.1
distlib 0.3.7
distro 1.7.0
distro-info 1.1+ubuntu0.1
docstring-parser 0.15
docstring-to-markdown 0.12
einops 0.6.1
entrypoints 0.4
ephem 4.1.4
evaluate 0.4.0
executing 1.2.0
facets-overview 1.0.3
fastapi 0.98.0
fastjsonschema 2.18.0
fasttext 0.9.2
filelock 3.6.0
flash-attn 1.0.7
Flask 1.1.2+db1
flatbuffers 23.5.26
fonttools 4.25.0
frozenlist 1.4.0
fsspec 2022.7.1
future 0.18.2
gast 0.4.0
gitdb 4.0.10
GitPython 3.1.27
google-api-core 2.8.2
google-auth 1.33.0
google-auth-oauthlib 0.4.6
google-cloud-core 2.3.3
google-cloud-storage 2.10.0
google-crc32c 1.5.0
google-pasta 0.2.0
google-resumable-media 2.5.0
googleapis-common-protos 1.56.4
greenlet 1.1.1
grpcio 1.48.1
grpcio-status 1.48.1
gunicorn 20.1.0
gviz-api 1.10.0
h11 0.14.0
h5py 3.7.0
holidays 0.27.1
horovod 0.28.1
htmlmin 0.1.12
httplib2 0.20.2
httptools 0.6.0
huggingface-hub 0.16.4
idna 3.3
ImageHash 4.3.1
imbalanced-learn 0.10.1
importlib-metadata 4.11.3
importlib-resources 6.0.1
ipykernel 6.17.1
ipython 8.10.0
ipython-genutils 0.2.0
ipywidgets 7.7.2
isodate 0.6.1
itsdangerous 2.0.1
jedi 0.18.1
jeepney 0.7.1
Jinja2 2.11.3
jmespath 0.10.0
joblib 1.2.0
joblibspark 0.5.1
jsonschema 4.16.0
jupyter-client 7.3.4
jupyter_core 4.11.2
jupyterlab-pygments 0.1.2
jupyterlab-widgets 1.0.0
keras 2.11.0
keyring 23.5.0
kiwisolver 1.4.2
langchain 0.0.217
langchainplus-sdk 0.0.20
langcodes 3.3.0
launchpadlib 1.10.16
lazr.restfulclient 0.14.4
lazr.uri 1.0.6
lazy_loader 0.3
libclang 15.0.6.1
librosa 0.10.0
lightgbm 3.3.5
llvmlite 0.38.0
LunarCalendar 0.0.9
Mako 1.2.0
Markdown 3.3.4
markdown-it-py 3.0.0
MarkupSafe 2.0.1
marshmallow 3.20.1
matplotlib 3.5.2
matplotlib-inline 0.1.6
mccabe 0.7.0
mdurl 0.1.2
mistune 0.8.4
mleap 0.20.0
mlflow-skinny 2.5.0
more-itertools 8.10.0
msgpack 1.0.5
multidict 6.0.4
multimethod 1.9.1
multiprocess 0.70.12.2
murmurhash 1.0.9
mypy-extensions 0.4.3
nbclient 0.5.13
nbconvert 6.4.4
nbformat 5.5.0
nest-asyncio 1.5.5
networkx 2.8.4
ninja 1.11.1
nltk 3.7
nodeenv 1.8.0
notebook 6.4.12
numba 0.55.1
numexpr 2.8.4
numpy 1.21.5
oauthlib 3.2.0
openai 0.27.8
openapi-schema-pydantic 1.2.4
opt-einsum 3.3.0
packaging 21.3
pandas 1.4.4
pandocfilters 1.5.0
paramiko 2.9.2
parso 0.8.3
pathspec 0.9.0
pathy 0.10.2
patsy 0.5.2
peft 0.4.0
petastorm 0.12.1
pexpect 4.8.0
phik 0.12.3
pickleshare 0.7.5
Pillow 9.2.0
pip 23.3.1
platformdirs 2.5.2
plotly 5.9.0
pluggy 1.0.0
pmdarima 2.0.3
pooch 1.7.0
preshed 3.0.8
prompt-toolkit 3.0.36
prophet 1.1.4
protobuf 3.19.4
psutil 5.9.0
psycopg2 2.9.3
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 8.0.0
pyasn1 0.4.8
pyasn1-modules 0.2.8
pybind11 2.11.1
pycparser 2.21
pydantic 1.10.6
pyflakes 3.0.1
Pygments 2.16.1
PyGObject 3.42.1
PyJWT 2.3.0
PyMeeus 0.5.12
PyNaCl 1.5.0
pyodbc 4.0.32
pyparsing 3.0.9
pyright 1.1.294
pyrsistent 0.18.0
pytesseract 0.3.10
python-apt 2.4.0+ubuntu2
python-dateutil 2.8.2
python-dotenv 1.0.0
python-editor 1.0.4
python-lsp-jsonrpc 1.0.0
python-lsp-server 1.7.1
pytoolconfig 1.2.2
pytz 2022.1
PyWavelets 1.3.0
PyYAML 6.0
pyzmq 23.2.0
regex 2022.7.9
requests 2.28.1
requests-oauthlib 1.3.1
responses 0.18.0
rich 13.6.0
rope 1.7.0
rsa 4.9
s3transfer 0.6.0
safetensors 0.3.2
scikit-learn 1.1.1
scipy 1.9.1
seaborn 0.11.2
SecretStorage 3.3.1
Send2Trash 1.8.0
sentence-transformers 2.2.2
sentencepiece 0.1.99
setuptools 63.4.1
shap 0.41.0
shtab 1.6.4
simplejson 3.17.6
six 1.16.0
slicer 0.0.7
smart-open 5.2.1
smmap 5.0.0
sniffio 1.2.0
soundfile 0.12.1
soupsieve 2.3.1
soxr 0.3.6
spacy 3.5.3
spacy-legacy 3.0.12
spacy-loggers 1.0.4
spark-tensorflow-distributor 1.0.0
SQLAlchemy 1.4.39
sqlparse 0.4.2
srsly 2.4.7
ssh-import-id 5.11
stack-data 0.6.2
starlette 0.27.0
statsmodels 0.13.2
tabulate 0.8.10
tangled-up-in-unicode 0.2.0
tenacity 8.1.0
tensorboard 2.11.0
tensorboard-data-server 0.6.1
tensorboard-plugin-profile 2.11.2
tensorboard-plugin-wit 1.8.1
tensorflow 2.11.1
tensorflow-estimator 2.11.0
tensorflow-io-gcs-filesystem 0.33.0
termcolor 2.3.0
terminado 0.13.1
testpath 0.6.0
thinc 8.1.12
threadpoolctl 2.2.0
tiktoken 0.4.0
tokenize-rt 4.2.1
tokenizers 0.13.3
tomli 2.0.1
torch 1.13.1+cu117
torchvision 0.14.1+cu117
tornado 6.1
tqdm 4.64.1
traitlets 5.1.1
transformers 4.30.2
trl 0.7.4
typeguard 2.13.3
typer 0.7.0
typing_extensions 4.7.1
typing-inspect 0.9.0
tyro 0.5.14
ujson 5.4.0
unattended-upgrades 0.1
urllib3 1.26.11
uvicorn 0.23.2
uvloop 0.17.0
virtualenv 20.16.3
visions 0.7.5
wadllib 1.3.6
wasabi 1.1.2
watchfiles 0.19.0
wcwidth 0.2.5
webencodings 0.5.1
websocket-client 0.58.0
websockets 11.0.3
Werkzeug 2.0.3
whatthepatch 1.0.2
wheel 0.37.1
widgetsnbextension 3.6.1
wordcloud 1.9.2
wrapt 1.14.1
xgboost 1.7.6
xxhash 3.3.0
yapf 0.31.0
yarl 1.9.2
ydata-profiling 4.2.0
zipp 3.8.0
If anyone has any idea how to fix this, please let me know!
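This might be related to a Databricks bug; see https://github.com/openai/openai-python/issues/751. Note that the traceback loads typing_extensions from the cluster-wide /databricks/python/lib/python3.10/site-packages directory, whereas trl and tyro are loaded from the notebook-scoped pythonEnv directory, so the version you installed may not be the one that is actually being imported.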