python, openai-whisper

How do I change the dataset when fine-tuning the Whisper model?


I tried to fine-tune the Whisper model by following the article. If you want to see the code, please look at the Colab link.

All I want to do is replace the Common Voice dataset used in the article with my own dataset.

With the prepared Common Voice dataset, everything works very well. The Common Voice dataset appears to use a pre-cached .arrow file, so it consumes very little memory.


Because of this, it is fast and the whole process runs smoothly. With my dataset, however, it does not work (it consumes a lot of memory).

Specifically, the following line takes a very long time and does not work:

common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=2)
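
For a sense of scale: assuming prepare_dataset computes Whisper's log-Mel input features (typical for Whisper fine-tuning), every example becomes a fixed-size array no matter how short the clip is. The arithmetic below is a rough sketch that assumes an 80-mel-bin checkpoint (large-v3 uses 128 bins), inputs padded or truncated to 30 s with a 10 ms hop (3000 frames), and float32 values; the helper name is hypothetical.

```python
# Rough size of Whisper log-Mel input features per example (hypothetical helper).
# Assumptions: 80 mel bins by default, 30 s of input at a 10 ms hop -> 3000
# frames, float32 values (4 bytes each). Container overhead is ignored.
N_FRAMES = 3000          # 30 s / 10 ms hop
BYTES_PER_FLOAT32 = 4


def feature_bytes_per_example(n_mels: int = 80) -> int:
    """Bytes in one example's input_features array (n_mels x N_FRAMES float32)."""
    return n_mels * N_FRAMES * BYTES_PER_FLOAT32


per_example = feature_bytes_per_example()
print(per_example)  # 960000, about 0.96 MB regardless of clip length

# Compare with a 5-second clip of decoded 16 kHz float32 audio.
clip_bytes = 5 * 16_000 * BYTES_PER_FLOAT32
print(per_example / clip_bytes)  # 3.0

# Features for 10,000 examples.
print(f"{10_000 * per_example / 1e9:.1f} GB")  # 9.6 GB
```

Under these assumptions the extracted features for short clips are larger than the decoded audio itself, which is why the map step is memory-hungry even when the audio is loaded efficiently.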

I think this is because the Common Voice raw data was pre-cached and mine is not. I load my dataset with the simple code below.

My code does not create an Arrow cache file for the decoded audio arrays.

import os
import re

import numpy as np
import soundfile as sf
from datasets import Dataset

# getDirList, getFileList and getJson are the author's own helpers (not shown);
# judging by their use, they list subdirectory paths, list the file names in a
# directory, and load a JSON file.


class DataLoader_AIHub:
    def __init__(self, rootPath):
        self.rootPath = rootPath

    def getData(self, max_files_to_load, startPoint=0):
        rootPath_audio = os.path.join(self.rootPath, 'audio')
        audioDirPaths = getDirList(rootPath_audio)

        total_files_loaded = 0

        data_list = []

        for audioDir in audioDirPaths:
            audioFileNames = getFileList(audioDir)
            audioFilePaths = [audioDir + '/' + str(item) for item in audioFileNames]
            # Each label is the matching .json file under label/ instead of audio/
            labelFilePaths = [item.replace('/audio/', '/label/').replace('.wav', '.json') for item in audioFilePaths]

            for audioPath, labelPath in zip(audioFilePaths, labelFilePaths):
                jsonInfo = getJson(labelPath)

                # '발화정보' = utterance info; skip transcripts that contain '('
                if '(' in jsonInfo['발화정보']['stt']:
                    continue

                # Skip the first startPoint usable files
                if startPoint > total_files_loaded:
                    total_files_loaded += 1
                    continue

                # The whole file is decoded here and kept in memory in data_list
                audio, sr = sf.read(audioPath)
                audioArray = audio.astype(np.float32)

                sample = {
                    'audio': {
                        'path': audioPath,
                        'array': audioArray,
                        'sampling_rate': sr
                    },
                    'sentence': re.sub('\r\n', '', jsonInfo['발화정보']['stt']),
                    # '녹음자정보' = speaker (recording participant) info
                    'age': jsonInfo['녹음자정보']['age'],
                    'gender': jsonInfo['녹음자정보']['gender']
                }

                data_list.append(sample)

                total_files_loaded += 1

                if total_files_loaded >= max_files_to_load + startPoint:
                    return Dataset.from_list(data_list)

        return Dataset.from_list(data_list)

(It is a Korean dataset.)

The voice files (.wav) are sampled at 16 kHz, and audioArray holds the decoded samples. I assume the .arrow file stores these decoded arrays.
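
To put the memory use of the loader above in numbers: every decoded array stays in data_list, so memory grows with the total duration of the audio. A rough sketch, assuming mono float32 samples at 16 kHz (as produced by `audio.astype(np.float32)`) and ignoring Python object overhead and any additional copy made when the list is converted into a Dataset; the helper name is hypothetical.

```python
# Hypothetical helper: rough size of decoded audio held in memory.
# Assumes mono float32 samples (4 bytes each); ignores per-object overhead.
BYTES_PER_FLOAT32 = 4


def decoded_audio_bytes(total_seconds: float, sampling_rate: int = 16_000) -> int:
    """Bytes needed to hold `total_seconds` of decoded mono float32 audio."""
    return int(total_seconds * sampling_rate * BYTES_PER_FLOAT32)


# One second of 16 kHz float32 audio.
print(decoded_audio_bytes(1))  # 64000

# Example: 10,000 clips of 5 seconds each.
total = decoded_audio_bytes(10_000 * 5)
print(f"{total / 1e9:.1f} GB")  # 3.2 GB
```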

Am I doing something wrong?


Solution

  • If you are experiencing the same issue, please refer to the link.

    The solution is as follows.

    from datasets import Audio, Dataset

    # Store only file paths; Audio() decodes each file when it is accessed
    audio_dataset = Dataset.from_dict(
        {"audio": ["audio_file_path_1.wav", "audio_file_path_2.wav"]}
    ).cast_column("audio", Audio())


    Then audio_dataset has the following format.

    Dataset({
        features: ['audio'],
        num_rows: 2
    })
    

    You can then inspect the contents with:

    audio_dataset['audio']
    


    Each entry of audio_dataset['audio'] is a dict with the keys 'path', 'array', and 'sampling_rate'. The important point is that the array is decoded on access, so the dataset itself consumes very little memory.

    You can now put tens of thousands of audio files into a single Dataset and train on all of them at once.


    Strictly speaking, the following line from the example above will still consume a lot of memory no matter what, because of feature extraction:

    common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=2)
    

    However, with the approach suggested here, allocating a large amount of swap space in the operating system can take the program from not being able to run at all to running without issues.
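
The idea behind this fix (keep only file paths, decode each file when it is accessed) can be illustrated without the datasets library. This is a simplified stand-in, not how `datasets.Audio` is implemented: the class name `LazyAudioIndex` is hypothetical, it assumes the audio/ and label/ directory layout and the JSON keys used by `DataLoader_AIHub` in the question, and it decodes 16-bit mono PCM WAV with the standard-library `wave` module instead of soundfile.

```python
import array
import json
import os
import sys
import wave


class LazyAudioIndex:
    """Hypothetical sketch: stores paths and transcripts, decodes audio on access.

    Assumes <root>/audio/<dir>/<name>.wav has a matching
    <root>/label/<dir>/<name>.json with the transcript under
    ['발화정보']['stt'] ('발화정보' = utterance info), as in the question.
    """

    def __init__(self, root_path):
        self.items = []  # (wav_path, sentence) pairs; no audio arrays stored
        audio_root = os.path.join(root_path, "audio")
        for dir_name in sorted(os.listdir(audio_root)):
            audio_dir = os.path.join(audio_root, dir_name)
            if not os.path.isdir(audio_dir):
                continue
            for file_name in sorted(os.listdir(audio_dir)):
                if not file_name.endswith(".wav"):
                    continue
                wav_path = os.path.join(audio_dir, file_name)
                label_path = os.path.join(root_path, "label", dir_name,
                                          file_name[:-4] + ".json")
                with open(label_path, encoding="utf-8") as f:
                    stt = json.load(f)["발화정보"]["stt"]
                if "(" in stt:  # same filter as the question's loader
                    continue
                self.items.append((wav_path, stt.replace("\r\n", "")))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        # Decoding happens here, only for the requested item.
        wav_path, sentence = self.items[idx]
        with wave.open(wav_path, "rb") as wf:
            if wf.getnchannels() != 1 or wf.getsampwidth() != 2:
                raise ValueError("sketch only handles 16-bit mono PCM")
            sr = wf.getframerate()
            samples = array.array("h", wf.readframes(wf.getnframes()))
        if sys.byteorder == "big":
            samples.byteswap()  # WAV sample data is little-endian
        audio = [s / 32768.0 for s in samples]  # scale to [-1.0, 1.0)
        return {"audio": {"path": wav_path, "array": audio, "sampling_rate": sr},
                "sentence": sentence}
```

Only the list of paths and transcripts stays in memory; each `index[i]` reads and decodes a single file. With the datasets library, the same path list would go into `Dataset.from_dict(...)` followed by `.cast_column("audio", Audio())`, as in the answer above.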