Source code for rl4lms.data_pools.custom_text_generation_pools

from rl4lms.data_pools.text_generation_pool import TextGenPool, Sample
from rl4lms.data_pools.task_utils.totto import preprocess_utils
from datasets import load_dataset
from tqdm import tqdm
from nltk.tokenize import word_tokenize
import os
from urllib.request import urlretrieve
from pathlib import Path
import pandas
from collections import defaultdict
import zipfile
import json


[docs]class ToTTo(TextGenPool):
[docs] @classmethod def prepare(cls, split: str, representation: str = 'subtable', **args) -> 'TextGenPool': ds = load_dataset('totto') samples = [] split_id = ToTTo.gen_split_name(split) n_samples = len(ds[split_id]) for ix, item in tqdm(enumerate(ds[split_id]), desc="Loading ToTTo dataset", total=n_samples): table = item["table"] table_page_title = item["table_page_title"] table_section_title = item["table_section_title"] cell_indices = item["highlighted_cells"] # can potentially add additional targets for various stages of annotation here targets = item["sentence_annotations"]["final_sentence"] if split_id == "test": targets = [""] # empty references instead of none subtable = ( preprocess_utils.get_highlighted_subtable( table=table, cell_indices=cell_indices, with_heuristic_headers=True)) if representation == 'subtable': subtable_str = ( preprocess_utils.linearize_subtable( subtable=subtable, table_page_title=None, table_section_title=None)) subtable_metadata_str = ( preprocess_utils.linearize_subtable( subtable=subtable, table_page_title=table_page_title, table_section_title=table_section_title)) prompt = subtable_str + subtable_metadata_str elif representation == 'fulltable': full_table_str = preprocess_utils.linearize_full_table( table=table, cell_indices=cell_indices, table_page_title=None, table_section_title=None) full_table_metadata_str = ( preprocess_utils.linearize_full_table( table=table, cell_indices=cell_indices, table_page_title=table_page_title, table_section_title=table_section_title)) prompt = full_table_str + full_table_metadata_str else: raise NotImplementedError # reformat sentence annotations (to fit in official totto eval) reformatted_sent_annotations = [] n_refs = len(targets) for ref_ix in range(n_refs): annotation = { "original_sentence": item["sentence_annotations"]["original_sentence"][ref_ix], "sentence_after_deletion": item["sentence_annotations"]["sentence_after_deletion"][ref_ix], "sentence_after_ambiguity": item["sentence_annotations"]["sentence_after_ambiguity"][ref_ix], "final_sentence": item["sentence_annotations"]["final_sentence"][ref_ix], } reformatted_sent_annotations.append(annotation) item["sentence_annotations"] = reformatted_sent_annotations # change overlap_subset to bool type item["overlap_subset"] = True if item["overlap_subset"] == "True" else False sample = Sample(id=f"{split}_{ix}", prompt_or_input_text=prompt, references=targets, meta_data={ "raw_table": item } ) samples.append(sample) pool_instance = cls(samples) return pool_instance
[docs] @staticmethod def gen_split_name(split: str): if split == "train": split_name = "train" elif split == "val": split_name = "validation" elif split == "test": split_name = "test" else: raise NotImplementedError return split_name
[docs]class CommonGen(TextGenPool):
[docs] @classmethod def prepare(cls, split: str, concept_separator_token: str = " ", concept_end_token=" ", prefix: str = "summarize: ") -> 'TextGenPool': ds = load_dataset("gem", "common_gen") samples = [] split_id = CommonGen.gen_split_name(split) for ix, item in enumerate(ds[split_id]): concepts = concept_separator_token.join(item["concepts"]) concepts = prefix + concepts concepts += concept_end_token if item["target"] == "": # just to avoid breaking of metric computation item["target"] = "empty reference" targets = [item["target"]] sample = Sample(id=f"{split}_{ix}", prompt_or_input_text=concepts, references=targets, meta_data={ "concepts": item["concepts"] } ) samples.append(sample) pool_instance = cls(samples) return pool_instance
[docs] @staticmethod def gen_split_name(split: str): if split == "train": split_name = "train" elif split == "val": split_name = "validation" elif split == "test": split_name = "test" else: raise NotImplementedError return split_name
[docs]class Xsum(TextGenPool):
[docs] @classmethod def prepare(cls, split: str, prompt_suffix: str = "TL;DR:"): dataset = load_dataset("gem", "xsum") dataset_split = dataset[split] samples = [] for ix, item in enumerate(dataset_split): sample = Sample(id=f"{split}_{ix}", prompt_or_input_text=item["document"] + prompt_suffix, references=[item["target"]] ) samples.append(sample) pool_instance = cls(samples) return pool_instance
[docs]class CNNDailyMail(TextGenPool):
[docs] @classmethod def prepare(cls, split: str, prompt_suffix: str = "", prompt_prefix: str = "", truncate_article: int = None, max_size: int = None): dataset = load_dataset("cnn_dailymail", "3.0.0") dataset_split = CommonGen.gen_split_name(split) samples = [] for ix, item in tqdm(enumerate(dataset[dataset_split]), desc="Tokenizing dataset", total=len(dataset[dataset_split])): if truncate_article is not None: tokens = word_tokenize(item["article"]) tokens = tokens[:truncate_article] item["article"] = " ".join(tokens) sample = Sample(id=f"{split}_{ix}", prompt_or_input_text=prompt_prefix + item["article"] + prompt_suffix, references=[item["highlights"]] ) samples.append(sample) if max_size is not None and ix == (max_size-1): break pool_instance = cls(samples) return pool_instance
[docs]class IMDB(TextGenPool): """ IMDB Dataset for sentiment continuation task """
[docs] @classmethod def prepare(cls, split: str, seed: int): dataset = load_dataset("imdb", ignore_verifications=True) if split in ["train", "val"]: dataset_split = dataset["train"].shuffle(seed) train_ratio = 0.8 train_index = int(len(dataset_split) * train_ratio) dataset_split = dataset_split[:train_index] if split == "train" else dataset_split[train_index:] else: dataset_split = dataset[split].shuffle(seed) dataset_split = dataset_split[:5000] samples = [] for ix, text in enumerate(dataset_split["text"]): # here we consider 50% of tokens as prompt prompt_text = text.split(" ") prompt_text = " ".join(prompt_text[:int(len(prompt_text) * 0.5)]) sample = Sample(id=f"{split}_{ix}", prompt_or_input_text=prompt_text, references=[text] ) samples.append(sample) pool_instance = cls(samples) return pool_instance
[docs]class IMDBForSeq2Seq(TextGenPool): """ IMDB Dataset in seq2seq format to train supervised generator """
[docs] @classmethod def prepare(cls, split: str, positive_ratio: int = 1.0): dataset = load_dataset("imdb") if split in ["train", "val"]: dataset_split = dataset["train"].shuffle() train_ratio = 0.8 train_index = int(len(dataset_split) * train_ratio) dataset_split = dataset_split[:train_index] if split == "train" else dataset_split[train_index:] else: # limit test to 5000 dataset_split = dataset[split].shuffle() dataset_split = dataset_split[:5000] samples = [] for ix, (text, label) in enumerate(zip(dataset_split["text"], dataset_split["label"])): # here we consider 50% of tokens as prompt and rest as references tokenized_text = text.split(" ") text_split_index = int(len(tokenized_text) * 0.5) prompt_text = " ".join(tokenized_text[:text_split_index]) ref_text = " ".join(tokenized_text[text_split_index:]) # add only positive examples for train set if split == "train" and label == 1 or split != "train": sample = Sample(id=f"{split}_{ix}", prompt_or_input_text=prompt_text, references=[ref_text], meta_data={ "reference": text } ) samples.append(sample) # truncate train split if split == "train": samples = samples[:int(len(samples) * positive_ratio)] pool_instance = cls(samples) return pool_instance
[docs]def download_file_using_url(url: str, dest_path: str): urlretrieve(url, dest_path)
[docs]class NarrativeQA(TextGenPool):
[docs] @classmethod def normalize_text(cls, text, strip: bool): # https: // github.com/allenai/unifiedqa/blob/7bf0653c6fb68a51019924fd4c51615155acbebe/tasks.py # L54-L58 text = text.lower() if strip: text = text.strip() text.replace("'", '') return text
[docs] @classmethod def prepare(cls, split: str): # URLs urls = { "train": "https://storage.googleapis.com/unifiedqa/data/narrativeqa/train.tsv", "val": "https://storage.googleapis.com/unifiedqa/data/narrativeqa/dev.tsv", "test": "https://storage.googleapis.com/unifiedqa/data/narrativeqa/test.tsv" } # destination path dest_base_path = os.path.join(Path.home(), "narrative_qa") os.makedirs(dest_base_path, exist_ok=True) # download the file split_path = os.path.join(dest_base_path, f"{split}.tsv") if not os.path.exists(split_path): download_file_using_url(urls[split], split_path) # load the split split_df = pandas.read_csv(split_path, sep='\t', header=None, encoding="utf-8") # group questions and answers prompts_and_answers = defaultdict(list) for _, (prompt, answer) in split_df.iterrows(): prompts_and_answers[prompt].append(answer) samples = [] for ix, (prompt, answers) in enumerate(prompts_and_answers.items()): sample = Sample(id=f"{split}_{ix}", prompt_or_input_text=cls.normalize_text( prompt, False), references=[cls.normalize_text(answer, True) for answer in answers]) samples.append(sample) dp_instance = cls(samples) return dp_instance
[docs]class WMT(TextGenPool):
[docs] @classmethod def get_dataset(cls, wmt_id: str, source_language: str, target_language: str, split: str): try: language_pair = f"{source_language}-{target_language}" dataset = load_dataset(f"{wmt_id}", language_pair) except: language_pair = f"{target_language}-{source_language}" dataset = load_dataset(f"{wmt_id}", language_pair) dataset = dataset[split] return dataset
[docs] @classmethod def prepare(cls, wmt_id: str, split: str, source_language: str, target_language: str, prompt_suffix: str = "", prompt_prefix: str = "", ): dataset_split = CommonGen.gen_split_name(split) dataset = WMT.get_dataset( wmt_id, source_language, target_language, dataset_split) samples = [] for ix, item in tqdm(enumerate(dataset), desc="Preparing dataset", total=len(dataset)): prompt = prompt_prefix + \ item["translation"][source_language] + prompt_suffix sample = Sample(id=f"{split}_{ix}", prompt_or_input_text=prompt, references=[item["translation"][target_language]] ) samples.append(sample) pool_instance = cls(samples) return pool_instance
[docs]class WMT14PreprocessedEnDe(TextGenPool):
[docs] @classmethod def get_dataset(cls, split: str): dataset = load_dataset("stas/wmt14-en-de-pre-processed")[split] return dataset
[docs] @classmethod def prepare(cls, split: str, prompt_suffix: str = "", prompt_prefix: str = "", ): dataset_split = CommonGen.gen_split_name(split) dataset = WMT14PreprocessedEnDe.get_dataset(dataset_split) samples = [] for ix, item in tqdm(enumerate(dataset), desc="Preparing dataset", total=len(dataset)): prompt = prompt_prefix + \ item["translation"]["en"] + prompt_suffix sample = Sample(id=f"{split}_{ix}", prompt_or_input_text=prompt, references=[item["translation"]["de"]] ) samples.append(sample) pool_instance = cls(samples) return pool_instance
[docs]class WMT16NewsOnlyDatasetEnDe(TextGenPool):
[docs] @classmethod def get_dataset(cls, split: str): if split == "train": dataset = load_dataset("news_commentary", "de-en")[split] return dataset else: dataset = load_dataset("wmt16", "de-en")[split] return dataset
[docs] @classmethod def prepare(cls, split: str, prompt_suffix: str = "", prompt_prefix: str = "", ): dataset_split = CommonGen.gen_split_name(split) dataset = WMT16NewsOnlyDatasetEnDe.get_dataset(dataset_split) samples = [] for ix, item in tqdm(enumerate(dataset), desc="Preparing dataset", total=len(dataset)): prompt = prompt_prefix + \ item["translation"]["en"] + prompt_suffix sample = Sample(id=f"{split}_{ix}", prompt_or_input_text=prompt, references=[item["translation"]["de"]] ) samples.append(sample) pool_instance = cls(samples) return pool_instance
[docs]class IWSLT2017EnDe(TextGenPool):
[docs] @classmethod def get_dataset(cls, split: str): dataset = load_dataset( "iwslt2017", "iwslt2017-de-en", ignore_verifications=True) return dataset[split]
[docs] @classmethod def prepare(cls, split: str, prompt_suffix: str = "", prompt_prefix: str = "", ): dataset_split = CommonGen.gen_split_name(split) dataset = IWSLT2017EnDe.get_dataset(dataset_split) samples = [] for ix, item in tqdm(enumerate(dataset), desc="Preparing dataset", total=len(dataset)): prompt = prompt_prefix + \ item["translation"]["en"] + prompt_suffix sample = Sample(id=f"{split}_{ix}", prompt_or_input_text=prompt, references=[item["translation"]["de"]] ) samples.append(sample) pool_instance = cls(samples) return pool_instance
[docs]class CRD3DialogueGeneration(TextGenPool): SOURCE_URL = "https://github.com/RevanthRameshkumar/CRD3/archive/refs/heads/master.zip" DEST_BASE_FOLDER = "crd3" DEST_EXTRACTED_FOLDER = "CRD3-master" ZIP_FILE_NAME = "master.zip" PATH_TO_ALIGNED_DATA = "data/aligned data" PATH_TO_CLEANED_DATA = "data/cleaned data"
[docs] @classmethod def prepare(cls, split: str, max_context_size: int): dest_base_path = os.path.join( Path.home(), CRD3DialogueGeneration.DEST_BASE_FOLDER) path_to_extracted_folder = os.path.join( dest_base_path, CRD3DialogueGeneration.DEST_EXTRACTED_FOLDER) path_to_zip_file = os.path.join( dest_base_path, CRD3DialogueGeneration.ZIP_FILE_NAME) # download and extract if it does not exist if not os.path.exists(path_to_extracted_folder): os.makedirs(dest_base_path, exist_ok=True) download_file_using_url( CRD3DialogueGeneration.SOURCE_URL, path_to_zip_file) with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref: zip_ref.extractall(dest_base_path) # get all train, val and test episode files ep_file_paths = { "train": os.path.join(path_to_extracted_folder, CRD3DialogueGeneration.PATH_TO_ALIGNED_DATA, "train_files"), "val": os.path.join(path_to_extracted_folder, CRD3DialogueGeneration.PATH_TO_ALIGNED_DATA, "val_files"), "test": os.path.join(path_to_extracted_folder, CRD3DialogueGeneration.PATH_TO_ALIGNED_DATA, "test_files"), } with open(ep_file_paths[split], "r") as fp: ep_file_names = fp.read().splitlines() # now, load all episode files for the split samples = [] for ep_ix, file_name in enumerate(ep_file_names): ep_path = os.path.join( path_to_extracted_folder, CRD3DialogueGeneration.PATH_TO_CLEANED_DATA, f"{file_name}.json") with open(ep_path, "r") as fp: ep_data = json.load(fp) # for each utterance, keep track of all the utterances before to get the context # initially empty for each episode contexts = [] for turn_ix, turn_data in enumerate(ep_data["TURNS"]): # current utterance string names = turn_data["NAMES"][0] turn_utterances = " ".join(turn_data["UTTERANCES"]) # sample if len(contexts) > 1: contexts_shortened = contexts[-max_context_size:] prompt_or_input = "\n".join(contexts_shortened) # add name to the prompt prompt_or_input = prompt_or_input + "\n" + names + ": " sample = Sample(id=f"{split}_ep_{ep_ix}_turn_{turn_ix}", prompt_or_input_text=prompt_or_input, references=[turn_utterances]) samples.append(sample) # update the context turn_string_full = names + ": " + turn_utterances contexts.append(turn_string_full) dp_instance = cls(samples) return dp_instance
[docs]class DailyDialog(TextGenPool): EOU_TOKEN = "<EOU>"
[docs] @classmethod def prepare(cls, split: str, context_size: int): split = CommonGen.gen_split_name(split) dataset = load_dataset("daily_dialog", split=split) samples = [] utterance_id = 0 for item in dataset: contexts = [] for utterance, emotion, intent in zip(item["dialog"], item["emotion"], item["act"]): if len(contexts) >= context_size: context = DailyDialog.EOU_TOKEN.join(contexts[-context_size:]) context += " " + DailyDialog.EOU_TOKEN target = utterance + DailyDialog.EOU_TOKEN sample = Sample(id=utterance_id, prompt_or_input_text=context, references=[target], meta_data={ "emotion": [emotion], "intent": [intent] }) samples.append(sample) contexts.append(utterance) utterance_id += 1 dp_instance = cls(samples) return dp_instance
if __name__ == "__main__": from transformers import AutoTokenizer import numpy as np dp = IMDB.prepare("test", 42) print(dp[0])