# Source code for calbert.CalBERTDataset

import random
from collections import Counter
from pathlib import Path
from typing import Union, List, Tuple

import torch
from torch.utils.data import Dataset
from tqdm.auto import tqdm


class CalBERTDataset(Dataset):
    """Dataset of (base language sentence, target language sentence, label) triples for CalBERT.

    At construction time the data can optionally be augmented with negative pairs (label ``0.0``) and
    shuffled, and a whitespace-token vocabulary is built from both sides of every pair.
    """

    def __init__(self, base_language_sentences: List[str], target_language_sentences: List[str],
                 labels: Union[List[float], torch.Tensor, None] = None, negative_sampling: bool = False,
                 negative_sampling_size: float = 0.5, negative_sampling_count: int = 1,
                 negative_sampling_type: str = 'target', min_count: int = 10, shuffle: bool = True):
        """Create a CalBERTDataset from a list of base language sentences and target language sentences.

        :param base_language_sentences: Sentences in the base language
        :param target_language_sentences: Sentences in the target (code-mixed) language
        :param labels: Labels (binary or similarity scores) indicating relationship between base and target
            sentences, one per pair; defaults to all ones
        :param negative_sampling: Whether to perform negative sampling of examples
        :param negative_sampling_size: Fraction of the dataset's examples for which negatives are sampled
        :param negative_sampling_count: Number of negative samples to sample per selected positive example
        :param negative_sampling_type: Whether to sample from the base language ('base'), the target
            language ('target') or both ('both')
        :param min_count: Minimum frequency of a token in the dataset to be included in the vocabulary
        :param shuffle: Whether to shuffle the dataset
        :raises ValueError: If negative sampling is requested with an unknown ``negative_sampling_type``,
            or when not enough distinct sentences exist to draw the requested negatives
        """
        self.base_language_sentences = base_language_sentences
        self.target_language_sentences = target_language_sentences
        if labels is None:
            self.labels = torch.ones(len(base_language_sentences))
        elif isinstance(labels, torch.Tensor):
            # Equivalent to torch.tensor(labels) but without PyTorch's copy-construct warning.
            self.labels = labels.detach().clone()
        else:
            self.labels = torch.tensor(labels)

        assert len(self.base_language_sentences) == len(self.target_language_sentences) == len(self.labels), \
            'base sentences, target sentences and labels must have the same length'

        self.total_examples = len(base_language_sentences)
        self.negative_sampling = negative_sampling
        self.negative_sampling_size = negative_sampling_size
        self.negative_sampling_count = negative_sampling_count
        self.negative_sampling_type = negative_sampling_type
        self.min_count = min_count

        # NOTE: this seeds the global ``random`` module (making sampling and shuffling reproducible),
        # which also resets the process-wide RNG state for any other user of ``random``.
        random.seed(0)

        if self.negative_sampling:
            self.sample_negative_examples(self.negative_sampling_type)

        if shuffle:
            # Shuffle a permutation of indices instead of the zipped triples: random.shuffle consumes the
            # RNG identically for any list of the same length, so the resulting order is the same, while
            # the label tensor keeps its dtype and an empty dataset no longer breaks ``zip(*...)``.
            order = list(range(len(self.base_language_sentences)))
            random.shuffle(order)
            self.base_language_sentences = [self.base_language_sentences[i] for i in order]
            self.target_language_sentences = [self.target_language_sentences[i] for i in order]
            self.labels = self.labels[torch.tensor(order, dtype=torch.long)]

        self.tokens = list()
        self.compute_vocabulary(self.min_count)

    def __len__(self) -> int:
        """Returns the total number of examples in the dataset.

        :return: Number of examples in the dataset
        """
        return self.total_examples

    def __getitem__(self, idx: int) -> Tuple[str, str, torch.Tensor]:
        """Obtain the base language sentence, target language sentence, and label at the given index.

        :param idx: Index of the example in the dataset
        :return: A tuple of base language sentence, target language sentence, and label (an element of the
            labels tensor) at the given index
        """
        return self.base_language_sentences[idx], self.target_language_sentences[idx], self.labels[idx]

    def sample_negative_examples(self, sampling: str = 'target') -> None:
        """Sample negative examples from the dataset for each selected positive example.

        ``int(negative_sampling_size * total_examples)`` examples are drawn without replacement. For each, the
        base sentence is paired with ``negative_sampling_count`` other examples' target sentences ('target'),
        and/or the target sentence with other examples' base sentences ('base'), each labelled ``0.0``. The
        augmented data is then shuffled and replaces the dataset contents.

        :param sampling: Whether to sample from the base language ('base') or the target language
            ('target') or both ('both')
        :raises ValueError: If ``sampling`` is not one of 'target', 'base' or 'both', or when not enough
            distinct sentences exist to draw the requested negatives
        :return: None
        """
        if sampling not in ('target', 'base', 'both'):
            raise ValueError(f"sampling must be 'target', 'base' or 'both', got {sampling!r}")

        current_examples = list(zip(self.base_language_sentences, self.target_language_sentences, self.labels))
        # Sentence frequencies let _sample_negative_indices detect, without touching the RNG, when no valid
        # set of negatives exists (the rejection loop would otherwise spin forever).
        target_counts = Counter(self.target_language_sentences)
        base_counts = Counter(self.base_language_sentences)
        new_examples = list()

        for translation, transliteration, _ in tqdm(
                random.sample(current_examples, int(self.negative_sampling_size * self.total_examples))):
            if sampling in ('target', 'both'):
                # Pair this base sentence with target sentences that differ from its own.
                for i in self._sample_negative_indices(self.target_language_sentences, transliteration,
                                                       target_counts):
                    new_examples.append((translation, self.target_language_sentences[i], 0.0))
            if sampling in ('base', 'both'):
                # Pair this target sentence with base sentences that differ from its own.
                for i in self._sample_negative_indices(self.base_language_sentences, translation, base_counts):
                    new_examples.append((self.base_language_sentences[i], transliteration, 0.0))

        current_examples += new_examples
        random.shuffle(current_examples)
        if not current_examples:
            # Empty dataset: nothing to unpack, contents stay as they are.
            return
        base_sentences, target_sentences, labels = zip(*current_examples)
        self.base_language_sentences = list(base_sentences)
        self.target_language_sentences = list(target_sentences)
        self.labels = torch.tensor(labels)
        self.total_examples = len(self.base_language_sentences)

    def _sample_negative_indices(self, sentences: List[str], positive: str, counts: Counter) -> List[int]:
        """Draw ``negative_sampling_count`` distinct indices whose sentence differs from ``positive``.

        Rejection-samples ``random.sample(range(total_examples), negative_sampling_count)`` until none of the
        drawn sentences equals ``positive``.

        :param sentences: Sentences indexed by example position (length ``total_examples``)
        :param positive: Sentence that must not appear among the sampled negatives
        :param counts: Frequency of every sentence in ``sentences``
        :raises ValueError: If fewer than ``negative_sampling_count`` sentences differ from ``positive``
        :return: List of sampled example indices
        """
        available = self.total_examples - counts[positive]
        if available < self.negative_sampling_count:
            raise ValueError(f'cannot sample {self.negative_sampling_count} negative(s): only {available} '
                             f'sentence(s) differ from {positive!r}')
        while True:
            indices = random.sample(range(self.total_examples), self.negative_sampling_count)
            if all(sentences[i] != positive for i in indices):
                return indices

    def compute_vocabulary(self, min_count: Union[int, None] = None) -> List[str]:
        """Compute the vocabulary of the dataset by finding tokens appearing at least min_count times.

        Tokens are whitespace-separated words from both base and target sentences. Newly qualifying tokens
        are added to any tokens already in the vocabulary; the order is deterministic (existing tokens
        first, then new tokens in first-seen order).

        :param min_count: Minimum frequency of a token in the dataset to be included in the vocabulary;
            when None, the stored ``min_count`` is used. A given value is stored for later calls.
        :return: List of tokens in the dataset appearing at least min_count times
        """
        self.min_count = min_count if min_count is not None else self.min_count
        counter = Counter()
        for sentence in tqdm(self.base_language_sentences):
            counter.update(sentence.split())
        for sentence in tqdm(self.target_language_sentences):
            counter.update(sentence.split())
        # dict keys form a de-duplicated, insertion-ordered union; a set would make the order depend on
        # string hash randomisation and therefore differ between interpreter runs.
        vocabulary = dict.fromkeys(self.tokens)
        vocabulary.update(dict.fromkeys(token for token, count in counter.items() if count >= self.min_count))
        self.tokens = list(vocabulary)
        return self.tokens

    def get_tokens(self) -> List[str]:
        """Returns the vocabulary of the dataset computed by compute_vocabulary.

        :return: List of tokens in vocabulary
        """
        return self.tokens

    def get_batch(self, start: int, end: int) -> Tuple[List[str], List[str], torch.Tensor]:
        """Returns a batch of examples from the dataset between the given start and end indices.

        :param start: Start index of the batch in the dataset
        :param end: End index (exclusive) of the batch in the dataset
        :return: A tuple of base language sentences, target language sentences, and labels between the given
            start and end indices
        """
        return (self.base_language_sentences[start:end], self.target_language_sentences[start:end],
                self.labels[start:end])

    def save(self, path: Union[str, Path]) -> None:
        """Save the dataset object to the given path (as a pickled Python object via torch.save).

        :param path: Path to save the dataset object
        :return: None
        """
        torch.save(self, path)

    @staticmethod
    def load(path: Union[str, Path]) -> 'CalBERTDataset':
        """Load a CalBERTDataset object from the given path.

        WARNING: this unpickles arbitrary Python objects; only load files from trusted sources.

        :param path: Path to load the dataset object
        :return: CalBERTDataset object
        """
        import inspect

        # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle a full Python object such
        # as this dataset; versions before 1.13 do not accept the keyword at all.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            return torch.load(path, weights_only=False)
        return torch.load(path)