Tags: nlp, pipeline, huggingface-transformers, bert-language-model, named-entity-recognition

Getting the input text from transformers pipeline


I am following the tutorial at https://huggingface.co/docs/transformers/pipeline_tutorial to use the transformers pipeline for inference. For example, the following snippet gets NER results from the "ner" pipeline.

    # KeyDataset is a util that will just output the item we're interested in.
    from transformers.pipelines.pt_utils import KeyDataset
    from datasets import load_dataset
    from transformers import pipeline
    model = ...
    tokenizer = ...
    pipe = pipeline("ner", model=model, tokenizer=tokenizer)
    dataset = load_dataset("my_ner_dataset", split="test")
    
    for extracted_entities in pipe(KeyDataset(dataset, "text")):
        print(extracted_entities)

In NER, as in many other applications, we would also like to get the input, so that the result can be stored as a (text, extracted_entities) pair for later processing. Basically, I am looking for something like:

    # KeyDataset is a util that will just output the item we're interested in.
    from transformers.pipelines.pt_utils import KeyDataset
    from datasets import load_dataset
    from transformers import pipeline
    model = ...
    tokenizer = ...
    pipe = pipeline("ner", model=model, tokenizer=tokenizer)
    dataset = load_dataset("my_ner_dataset", split="test")
    
    for text, extracted_entities in pipe(KeyDataset(dataset, "text")):
        print(text, extracted_entities)

Where text is the raw input text (possibly batched) that gets fed into the pipeline.

Is this doable?


Solution


    # Datasets 2.11.0
    from datasets import load_dataset
    # Transformers 4.27.4, Torch 2.0.0+cu118
    from transformers import (
        AutoTokenizer,
        AutoModelForTokenClassification,
        pipeline
    )
    from transformers.pipelines.pt_utils import KeyDataset
    
    model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
    tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
    
    pipe = pipeline(task="ner", model=model, tokenizer=tokenizer)
    dataset = load_dataset("argilla/gutenberg_spacy-ner", split="train")
    results = pipe(KeyDataset(dataset, "text"))
    
    for idx, extracted_entities in enumerate(results):
        print("Original text:\n{}".format(dataset[idx]["text"]))
        print("Extracted entities:")
        for entity in extracted_entities:
            print(entity)
    

    Example output

    Original text:
    Would I wish to send up my name now ? Again I declined , to the polite astonishment of the concierge , who evidently considered me a queer sort of a friend . He was called to his desk by a guest , who wished to ask questions , of course , and I waited where I was . At a quarter to eleven Herbert Bayliss emerged from the elevator . His appearance almost shocked me . Out late the night before ! He looked as if he had been out all night for many nights .
    Extracted entities:
    {'entity': 'B-PER', 'score': 0.9996532, 'index': 68, 'word': 'Herbert', 'start': 289, 'end': 296}
    {'entity': 'I-PER', 'score': 0.9996567, 'index': 69, 'word': 'Bay', 'start': 297, 'end': 300}
    {'entity': 'I-PER', 'score': 0.9991698, 'index': 70, 'word': '##lis', 'start': 300, 'end': 303}
    {'entity': 'I-PER', 'score': 0.96547437, 'index': 71, 'word': '##s', 'start': 303, 'end': 304}
    
    ...
    
    Original text:
    And you think our run will be better than five hundred and eighty ? '' `` It should be , unless there is a remarkable change . This ship makes over six hundred , day after day , in good weather . She should do at least six hundred by to-morrow noon , unless there is a sudden change , as I said . '' `` But six hundred would be -- it would be the high field , by Jove ! '' `` Anything over five hundred and ninety-four would be that . The numbers are very low to-night .
    Extracted entities:
    {'entity': 'B-MISC', 'score': 0.40225995, 'index': 90, 'word': 'Jo', 'start': 363, 'end': 365}
    

    Brief Explanation

    Each sample in the dataset created by the load_dataset call can be accessed using an index and the associated dictionary key.

    Calling the pipeline object with a KeyDataset as input returns a PipelineIterator object, which is iterable. Hence, one can enumerate the PipelineIterator to obtain both each result and its index, and then use that index to retrieve the associated sample from the dataset.
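    The same pattern in miniature, with plain Python lists standing in for the dataset and the PipelineIterator (all names below are illustrative stand-ins, not part of the transformers API):

```python
# Stand-ins for illustration: `samples` plays the role of the loaded
# dataset, `fake_results` the role of the PipelineIterator's output.
samples = [{"text": "Herbert Bayliss emerged."}, {"text": "By Jove!"}]
fake_results = iter([[{"entity": "B-PER"}], [{"entity": "B-MISC"}]])

pairs = []
for idx, extracted_entities in enumerate(fake_results):
    # The enumeration index lines up with the dataset index because
    # the pipeline yields results in input order.
    pairs.append((samples[idx]["text"], extracted_entities))

print(pairs[0][0])  # Herbert Bayliss emerged.
```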

    Detailed Explanation

    The Huggingface pipeline abstraction is a wrapper around all the available pipelines. When a pipeline object is instantiated, the appropriate pipeline is returned based on the task argument:

    pipe = pipeline(task="ner", model=model, tokenizer=tokenizer)
    

    Given that the NER task is specified, a TokenClassificationPipeline will be returned (side note: "ner" is an alias for "token-classification"). This pipeline (and every other) inherits from the base class Pipeline. The Pipeline base class defines the __call__ function, which the TokenClassificationPipeline class relies on whenever the instantiated pipeline is called.
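    The alias can be checked directly against the task registry (assuming a transformers version where TASK_ALIASES is exposed in transformers.pipelines, as in the 4.27 release used above):

```python
from transformers.pipelines import TASK_ALIASES

# "ner" resolves to the canonical "token-classification" task name
print(TASK_ALIASES["ner"])  # token-classification
```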

    Once a pipeline is instantiated (see above), it can be called with data passed in as a single string, a list of strings, or, when working with full datasets, a Huggingface dataset wrapped in the transformers.pipelines.pt_utils KeyDataset class.

    dataset = load_dataset("argilla/gutenberg_spacy-ner", split="train")
    results = pipe(KeyDataset(dataset, "text"))  # pipeline call
    

    When the pipeline is called, it checks whether the data passed in is iterable and dispatches to an appropriate function. For Huggingface Dataset objects, the get_iterator function is called, which returns a PipelineIterator object. Since a PipelineIterator behaves like any other Python iterator, one can enumerate it to obtain tuples of a count (starting from 0 by default) and the successive values yielded by the iterator; here, the values are the NER extractions for each sample in the dataset. Hence, the following produces the desired results:

    for idx, extracted_entities in enumerate(results):
        print("Original text:\n{}".format(dataset[idx]["text"]))
        print("Extracted entities:")
        for entity in extracted_entities:
            print(entity)