-
-
Notifications
You must be signed in to change notification settings - Fork 48
Expand file tree
/
Copy pathbert_loader.py
More file actions
83 lines (53 loc) · 2.12 KB
/
bert_loader.py
File metadata and controls
83 lines (53 loc) · 2.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from typing import Dict, List
from torch import LongTensor
from transformers import DataCollatorForLanguageModeling
class BERTIterator:
    """Iterates over an encoded dataset in fixed-size batches of
    fixed-length token sequences, collated for masked-language-model
    (MLM) training with HuggingFace's ``DataCollatorForLanguageModeling``.

    Examples are non-overlapping ``sentence_len``-token slices of the
    reader's ``encoded_text`` tensor; a trailing partial batch (smaller
    than ``batch_size``) is dropped.
    """

    def __init__(self, dataset_reader, batch_size: int, sentence_len: int):
        """
        Args:
            dataset_reader: project reader that exposes ``encoded_text``
                (a 1-D LongTensor of token ids, populated by ``read()``)
                and an ``encoder.tokenizer_ref`` HuggingFace tokenizer.
            batch_size: number of examples per yielded batch.
            sentence_len: number of tokens per example.
        """
        self.dataset_reader = dataset_reader
        self.batch_size = batch_size
        self.sentence_len = sentence_len
        # Initialize the cursor here so calling __next__ before __iter__
        # behaves like an exhausted/fresh iterator rather than raising
        # AttributeError.
        self.index = 0
        # 15% masking probability — the standard value from the BERT paper.
        self.data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.dataset_reader.encoder.tokenizer_ref,
            mlm=True,
            mlm_probability=0.15,
        )

    def load(self, dataset_meta) -> LongTensor:
        """Read the dataset described by ``dataset_meta``.

        Returns the encoded token tensor so the caller can inspect or
        display the data.
        """
        self.dataset_reader.read(dataset_meta)
        return self.dataset_reader.encoded_text

    def __iter__(self):
        # Restart iteration from the beginning of the dataset.
        self.index = 0
        return self

    def __next__(self) -> Dict:
        # Drop the final batch if it would be smaller than batch_size.
        if self.index + self.batch_size > self.num_examples:
            raise StopIteration
        batch_examples = [self._load_example() for _ in range(self.batch_size)]
        return self._collate(batch_examples=batch_examples)

    @property
    def num_examples(self) -> int:
        """Number of non-overlapping examples in the dataset.

        One token is reserved (the ``- 1``) — presumably for a shifted
        target; TODO confirm against the training loop.
        """
        return (len(self.dataset_reader.encoded_text) - 1) // self.sentence_len

    @property
    def num_batches(self) -> int:
        """Total number of full batches; a trailing partial batch is
        dropped if its size is less than ``self.batch_size``.
        """
        return self.num_examples // self.batch_size

    def _load_example(self) -> LongTensor:
        """Return the next ``sentence_len``-token slice and advance the cursor."""
        dataset = self.dataset_reader.encoded_text
        # narrow() gives a zero-copy view of length sentence_len starting
        # at the current cursor position.
        example = dataset.narrow(
            dim=0, start=self.index * self.sentence_len, length=self.sentence_len
        )
        self.index += 1
        return example

    def _collate(self, batch_examples: List) -> Dict:
        """Stack the examples and apply MLM masking via the collator."""
        return self.data_collator(batch_examples)