I'm writing a pytest file to check whether my machine learning libraries use the GPU. For TensorFlow I can check this with `tf.config.list_physical_devices()`. For XGBoost I've so far checked it by watching GPU utilization (`nvidia-smi`) while running my software. But how can I check this in a simple test? Something similar to my TensorFlow test would do.
```python
import pytest
import tensorflow as tf
import xgboost

# Marking all tests to be GPU dependent
pytestmark = pytest.mark.gpu


def test_tf_finds_gpu():
    """Check if TensorFlow finds the GPU."""
    assert tf.config.list_physical_devices("GPU")


def test_xgb_finds_gpu():
    """Check if XGBoost finds the GPU."""
    ...
    # What can I write here?
```
Note that `tree_method="gpu_hist"` is deprecated since `xgboost==2.0.0` and will eventually stop working. The histogram type and the device are now split into two parameters: `tree_method` (an unfortunate reuse of the existing parameter name, but with a different set of permitted values) and a new one called `device`:
```python
import numpy as np
import xgboost as xgb

xgb_model = xgb.XGBRegressor(
    # tree_method="gpu_hist",  # deprecated
    tree_method="hist",
    device="cuda",
)
X = np.random.rand(50, 2)
y = np.random.randint(2, size=50)
xgb_model.fit(X, y)
xgb_model
```
Output when GPU access works correctly (no warnings):
```
XGBRegressor(base_score=None, booster=None, callbacks=None,
             colsample_bylevel=None, colsample_bynode=None,
             colsample_bytree=None, device='cuda',
             early_stopping_rounds=None,
             enable_categorical=False, eval_metric=None, feature_types=None,
             gamma=None, grow_policy=None, importance_type=None,
             interaction_constraints=None, learning_rate=None, max_bin=None,
             max_cat_threshold=None, max_cat_to_onehot=None,
             max_delta_step=None, max_depth=None, max_leaves=None,
             min_child_weight=None, missing=nan, monotone_constraints=None,
             multi_strategy=None, n_estimators=None, n_jobs=None,
             num_parallel_tree=None, random_state=None, ...)
```
Without GPU access, by contrast, XGBoost warns that the `device` argument has not been used:
```
[11:43:35] WARNING: ../src/learner.cc:767:
Parameters: { "device" } are not used.
```