
import torch
from torch.utils.data import Dataset
- from .config import *
+ from config import *
from transformers import PreTrainedTokenizerFast


class NERDataset(Dataset):
-    def __init__(self, data, tokenizer, max_len=128):
-        self.data = data
+    def __init__(self, data, tokenizer, max_len=512, exceed_strategy="truncation"):
        self.tokenizer = tokenizer
        self.max_len = max_len
+        self.data = []
+
+        if exceed_strategy == "truncation":
+            self.data = data
+        elif exceed_strategy == "sliding_window":
+            # only fast tokenizers are supported for now
+            for item in data:
+                text = item['text']
+                entities = item['entities']
+
+                full_encoding = self.tokenizer(
+                    text,
+                    add_special_tokens=False,
+                    return_offsets_mapping=True,
+                    return_tensors="pt"
+                )
+
+                tokens = full_encoding.tokens()
+                offset_mapping = full_encoding["offset_mapping"].squeeze().tolist()
+
+                if len(tokens) <= max_len:
+                    # item['encoding'] = full_encoding
+                    self.data.append(item)
+                    continue
+
+                window_size = max_len
+                stride = window_size // 2
+
+                start_token_idx = 0
+                while start_token_idx < len(tokens):
+                    end_token_idx = min(start_token_idx + window_size, len(tokens))
+
+                    # [start_token_idx, end_token_idx) ==> [start_char_idx, end_char_idx)
+                    start_char_idx = offset_mapping[start_token_idx][0]
+                    end_char_idx = offset_mapping[end_token_idx - 1][1]
+
+                    # For each window, keep only entities that lie entirely inside it
+                    # (this may shrink the window).
+                    for entity in entities:
+                        # Assumption: entity lengths are far smaller than window_size and stride.
+                        if entity['start'] <= start_char_idx < entity['end']:
+                            start_char_idx = entity['end']
+                        if entity['start'] <= end_char_idx < entity['end']:
+                            end_char_idx = entity['start']
+                            break
+                    # The corresponding token indices should also be updated here,
+                    # but we don't handle that for now.
+
+                    window_entities = []
+                    for entity in entities:
+                        if entity['start'] >= start_char_idx and entity['end'] <= end_char_idx:
+                            new_entity = entity.copy()
+                            new_entity['start'] -= start_char_idx
+                            new_entity['end'] -= start_char_idx
+                            window_entities.append(new_entity)
+
+                    window_text = text[start_char_idx:end_char_idx]
+                    window_data = {
+                        'text': window_text,
+                        'entities': window_entities  # no encoding attached for now
+                    }
+                    self.data.append(window_data)
+
+                    start_token_idx += stride  # overlapping windows
+        else:
+            raise ValueError(f"unknown exceed_strategy: {exceed_strategy}")

    def __len__(self):
        return len(self.data)
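
An aside on the new sliding_window branch: it turns one long record into several overlapping records of at most max_len tokens each, nudging window edges off entity spans so that no kept entity is cut in half. A minimal usage sketch (the checkpoint name and toy record are illustrative only, assuming the repo's start/end/type character-offset entity schema):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")  # any fast tokenizer
    record = {
        "text": "患者主诉头痛三天。" * 100,  # long enough to exceed max_len tokens
        "entities": [{"start": 4, "end": 6, "type": "SYMPTOM"}],  # "头痛"
    }
    ds = NERDataset([record], tokenizer, max_len=128, exceed_strategy="sliding_window")
    print(len(ds))  # several windows, overlapping by stride = 128 // 2 = 64 tokens
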
@@ -31,21 +95,24 @@ def __getitem__(self, idx):
        for i in range(start + 1, end):
            char_labels[i] = f"I-{entity_type}"

+        # align char_labels to token-level labels
        if isinstance(self.tokenizer, PreTrainedTokenizerFast):
-            encoding = self.tokenizer(
-                text,
-                add_special_tokens=False,
-                max_length=self.max_len,
-                padding="max_length",
-                truncation=True,
-                return_offsets_mapping=True,
-                return_tensors="pt"
-            )
-
+            if self.data[idx].get('encoding'):
+                encoding = self.data[idx]['encoding']  # may have been cached during preprocessing (note: cached encodings above are built without padding)
+            else:
+                encoding = self.tokenizer(
+                    text,
+                    add_special_tokens=False,
+                    max_length=self.max_len,
+                    padding="max_length",
+                    truncation=True,
+                    return_offsets_mapping=True,
+                    return_tensors="pt"
+                )
            input_ids = encoding["input_ids"].squeeze()
            attention_mask = encoding["attention_mask"].squeeze()
-            tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
-            offset_mapping = encoding["offset_mapping"].squeeze().tolist()
+            tokens = encoding.tokens()
+            offset_mapping = encoding["offset_mapping"].squeeze().tolist()  # character span of each token in the original text

            # derive token_labels from the entities
            token_labels = []
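
The lines elided between these hunks presumably walk offset_mapping to project char_labels onto tokens. As a standalone illustration of that alignment idea (not the commit's exact code), one common pattern takes each token's label from its first character:

    def align_char_labels_to_tokens(char_labels, offset_mapping, pad_label="O"):
        # Padding/special tokens carry the empty span (0, 0) and fall back to pad_label.
        return [pad_label if start == end else char_labels[start]
                for start, end in offset_mapping]

    # align_char_labels_to_tokens(["O", "B-SYMPTOM", "I-SYMPTOM"],
    #                             [(0, 1), (1, 2), (2, 3), (0, 0)])
    # -> ["O", "B-SYMPTOM", "I-SYMPTOM", "O"]
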
@@ -70,7 +137,7 @@ def __getitem__(self, idx):

            input_ids = encoding["input_ids"].squeeze()
            attention_mask = encoding["attention_mask"].squeeze()
-            tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
+            tokens = self.tokenizer.convert_ids_to_tokens(input_ids)  # keep convert_ids_to_tokens: BatchEncoding.tokens() raises for non-fast tokenizers, and this is the slow-tokenizer branch

            token_labels = []
            char_idx = 0
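
The elided slow-tokenizer branch aligns labels by walking the text manually, as the token_labels/char_idx initializers suggest. A heavily simplified sketch of that style of alignment (illustrative only; assumes unpadded WordPiece tokens with no [UNK], so every piece matches the text):

    def align_with_slow_tokenizer(text, tokens, char_labels):
        token_labels = []
        char_idx = 0
        for token in tokens:
            piece = token[2:] if token.startswith("##") else token
            # skip whitespace the tokenizer dropped
            while char_idx < len(text) and text[char_idx].isspace():
                char_idx += 1
            token_labels.append(char_labels[char_idx])  # label of the piece's first char
            char_idx += len(piece)
        return token_labels
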
@@ -94,7 +161,7 @@ def __getitem__(self, idx):


class REDataset(Dataset):
-    def __init__(self, data, tokenizer, max_len=128):
+    def __init__(self, data, tokenizer, max_len=512, exceed_strategy="truncation"):
        self.tokenizer = tokenizer
        self.max_len = max_len

@@ -103,13 +170,19 @@ def __init__(self, data, tokenizer, max_len=128):
            text = line['text']
            entities = line['entities']
            relations = line['relations']
-            for relation in relations:
-                self.data.append({
-                    'text': text,
-                    'e1': next(filter(lambda x: x['id'] == relation['source_id'], entities)),
-                    'e2': next(filter(lambda x: x['id'] == relation['target_id'], entities)),
-                    'relation': relation['type']
-                })
+
+            if exceed_strategy == "truncation":
+                for relation in relations:
+                    self.data.append({
+                        'text': text,
+                        'e1': next(filter(lambda x: x['id'] == relation['source_id'], entities)),
+                        'e2': next(filter(lambda x: x['id'] == relation['target_id'], entities)),
+                        'relation': relation['type']
+                    })
+            elif exceed_strategy == "sliding_window":
+                pass  # TODO: not implemented yet
+            else:
+                raise ValueError(f"unknown exceed_strategy: {exceed_strategy}")

    def __len__(self):
        return len(self.data)
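
In the truncation branch each relation becomes one classification item, so a record with N relations yields N items sharing the same text. A toy record (field values illustrative):

    line = {
        "text": "Aspirin treats headache.",
        "entities": [
            {"id": "T1", "start": 0,  "end": 7,  "type": "DRUG"},
            {"id": "T2", "start": 15, "end": 23, "type": "SYMPTOM"},
        ],
        "relations": [{"source_id": "T1", "target_id": "T2", "type": "treats"}],
    }
    # -> one item: {'text': 'Aspirin treats headache.',
    #               'e1': <the T1 entity>, 'e2': <the T2 entity>, 'relation': 'treats'}
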
@@ -139,11 +212,11 @@ def __getitem__(self, idx):
            attention_mask = encoding["attention_mask"].squeeze()
            offset_mapping = encoding["offset_mapping"].squeeze().tolist()

-            e1_mask = _create_entity_mask(input_ids, offset_mapping, e1_start, e1_end)
+            e1_mask = _create_entity_mask(input_ids, offset_mapping, e1_start, e1_end)  # mask is 1 at entity token positions
            e2_mask = _create_entity_mask(input_ids, offset_mapping, e2_start, e2_end)

        else:

            encoding = self.tokenizer(
                text,
                add_special_tokens=False,
@@ -156,10 +229,10 @@ def __getitem__(self, idx):
            input_ids = encoding["input_ids"].squeeze()
            attention_mask = encoding["attention_mask"].squeeze()
            tokens = self.tokenizer.convert_ids_to_tokens(input_ids)

            e1_mask = _create_entity_mask2(text, input_ids, tokens, e1_start, e1_end)
            e2_mask = _create_entity_mask2(text, input_ids, tokens, e2_start, e2_end)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,