I am implementing a fraud detection approach for graph-based data, from this article: https://developer.nvidia.com/blog/optimizing-fraud-detection-in-financial-services-with-graph-neural-networks-and-nvidia-gpus/. In the "Training the GNN model" and "Using GNN embeddings for downstream tasks" parts, the article suggests translating a tabular dataset into a graph, then generating a 64-width embedding for each node, before joining the embeddings to the original tabular dataset on the respective node IDs. I have generated a 64-width node embedding tensor; however, I am unsure what to do next. I have thought about condensing the tensor into one dimension and appending it based on the node IDs, but I feel like that is not what the article is suggesting. I have also thought of putting the entire tensor into a single cell, but I feel like that is not going to work with most machine learning models. What should I do in this situation? I do apologize if this does not seem like the right place to ask the question, and will remove it if that is the case.
To understand what kind of embeddings the authors of this NVIDIA blog use to enrich the dataset, let's rephrase their approach:
1. They build a graph in which card IDs and merchants are nodes, and transactions represent edges. I've visualized an example below to make it clear which are nodes and which are edges (it also proves why I preferred hard sciences to art school).
2. They train node embeddings using a link prediction task, that is to say, they train a model that predicts the probability that a card ID and a merchant are connected. Note: this is the 64-dimensional embedding the OP referred to.
3. Because the task is about whether a given merchant and a given card ID are connected, you only need one embedding per transaction, since every transaction has only one merchant and one card.
4. That embedding is added to the tabular dataset as 64 features. XGBoost cannot handle a single feature containing a vector of 64 dimensions, which is why 64 features, each containing a single number, is the only valid option.

Figure: Visualization of the type of graph created. For the sake of simplicity, the visualization assumes that every person owns only one card.

Since the dataset does not contain a transaction ID and every row in the dataset contains one transaction, you can consider the row index a transaction ID.
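To make the "64 features, each containing a single number" idea concrete, here is a minimal pandas sketch. It is not the article's code: the column names (`card_id`, `merchant_id`, `amount`), the `card_emb_` prefix, and the random array standing in for the trained embeddings are all assumptions for illustration. It only shows the mechanics of expanding an `(n_nodes, 64)` array into 64 scalar columns and joining them onto the transactions by node ID. If your embeddings are a PyTorch tensor, something like `emb.detach().cpu().numpy()` should give you the NumPy array first.

```python
import numpy as np
import pandas as pd

EMB_DIM = 64

# Toy transactions table; the column names are hypothetical.
transactions = pd.DataFrame({
    "card_id": [0, 1, 0, 2],
    "merchant_id": [10, 10, 11, 12],
    "amount": [12.5, 80.0, 3.2, 41.0],
})

# Stand-in for the trained GNN output: one 64-dim vector per card node.
# Random values are used here only so the example runs on its own.
rng = np.random.default_rng(0)
card_ids = np.array([0, 1, 2])
card_emb = rng.normal(size=(len(card_ids), EMB_DIM))

# Expand the (n_nodes, 64) array into 64 scalar columns keyed by node ID.
emb_cols = [f"card_emb_{i}" for i in range(EMB_DIM)]
card_emb_df = pd.DataFrame(card_emb, columns=emb_cols)
card_emb_df.insert(0, "card_id", card_ids)

# Left join keeps every transaction row and attaches its card's embedding.
enriched = transactions.merge(card_emb_df, on="card_id", how="left")
print(enriched.shape)  # (4, 67): 3 original columns + 64 embedding columns
```

The resulting `enriched` frame has only scalar columns, so it can be passed to a tabular model such as XGBoost like any other feature matrix. Merchant embeddings could be joined the same way on `merchant_id` with a different column prefix; whether you attach one side, both, or a combination of the two is a modelling choice this sketch does not settle.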