temporal supported
parent 54e6fbc84c
commit 5f1518cfd9
data/icews14/about.txt (new file, 15 lines)
@@ -0,0 +1,15 @@
# triples: 89320
# entities: 7128
# relations: 12409
# timesteps: 208
# test triples: 8255
# valid triples: 8239
# train triples: 72826
Measure method: N/A
Target Size : 0
Grow Factor: 0
Shrink Factor: 0
Epsilon Factor: 0
Search method: N/A
filter_dupes: inter
nonames: False
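Note: about.txt is a flat "key: value" summary, and the split counts are internally consistent (72826 train + 8239 valid + 8255 test = 89320 triples; the same holds for the other two datasets below). A minimal sketch for reading such a file back and checking that invariant, assuming every line follows the pattern above; the helper name read_about is illustrative, not part of this repository:

# Illustrative sketch: parse an about.txt summary and sanity-check the splits.
def read_about(path):
    stats = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(':')
            stats[key.strip()] = value.strip()
    return stats

stats = read_about('./data/icews14/about.txt')
parts = sum(int(stats[k]) for k in ('# train triples', '# valid triples', '# test triples'))
assert parts == int(stats['# triples'])  # 72826 + 8239 + 8255 == 89320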
data/icews14/entities.dict (new file, 7128 lines)
File diff suppressed because it is too large
data/icews14/relations.dict (new file, 12409 lines)
File diff suppressed because it is too large
data/icews14/test.txt (new file, 8255 lines)
File diff suppressed because it is too large
data/icews14/time_map.dict (new file, 209 lines)
@@ -0,0 +1,209 @@
0 0 2
1 3 5
2 6 7
3 8 9
4 10 12
5 13 14
6 15 16
7 17 19
8 20 21
9 22 23
10 24 26
11 27 28
12 29 30
13 31 33
14 34 35
15 36 37
16 38 40
17 41 42
18 43 44
19 45 46
20 47 48
21 49 49
22 50 50
23 51 51
24 52 53
25 54 54
26 55 55
27 56 57
28 58 59
29 60 61
30 62 62
31 63 63
32 64 65
33 66 68
34 69 70
35 71 71
36 72 72
37 73 74
38 75 76
39 77 78
40 79 80
41 81 82
42 83 84
43 85 85
44 86 87
45 88 89
46 90 91
47 92 93
48 94 96
49 97 97
50 98 99
51 100 101
52 102 103
53 104 105
54 106 107
55 108 110
56 111 112
57 113 114
58 115 116
59 117 118
60 119 119
61 120 121
62 122 124
63 125 125
64 126 127
65 128 129
66 130 131
67 132 133
68 134 135
69 136 138
70 139 139
71 140 140
72 141 141
73 142 143
74 144 145
75 146 147
76 148 148
77 149 150
78 151 152
79 153 154
80 155 155
81 156 157
82 158 159
83 160 161
84 162 163
85 164 166
86 167 167
87 168 168
88 169 169
89 170 170
90 171 173
91 174 175
92 176 177
93 178 180
94 181 182
95 183 183
96 184 185
97 186 187
98 188 188
99 189 190
100 191 192
101 193 194
102 195 195
103 196 197
104 198 199
105 200 201
106 202 203
107 204 205
108 206 208
109 209 210
110 211 212
111 213 215
112 216 217
113 218 219
114 220 221
115 222 222
116 223 224
117 225 226
118 227 229
119 230 231
120 232 233
121 234 236
122 237 238
123 239 239
124 240 241
125 242 243
126 244 245
127 246 246
128 247 248
129 249 250
130 251 251
131 252 252
132 253 253
133 254 254
134 255 256
135 257 257
136 258 259
137 260 261
138 262 263
139 264 264
140 265 265
141 266 266
142 267 267
143 268 269
144 270 271
145 272 272
146 273 273
147 274 274
148 275 276
149 277 278
150 279 279
151 280 281
152 282 283
153 284 285
154 286 286
155 287 287
156 288 288
157 289 289
158 290 291
159 292 292
160 293 293
161 294 294
162 295 295
163 296 297
164 298 299
165 300 300
166 301 301
167 302 303
168 304 305
169 306 307
170 308 309
171 310 310
172 311 312
173 313 313
174 314 314
175 315 315
176 316 316
177 317 317
178 318 319
179 320 320
180 321 321
181 322 322
182 323 323
183 324 324
184 325 326
185 327 327
186 328 328
187 329 329
188 330 330
189 331 332
190 333 334
191 335 335
192 336 336
193 337 338
194 339 340
195 341 342
196 343 343
197 344 344
198 345 346
199 347 348
200 349 349
201 350 350
202 351 352
203 353 355
204 356 357
205 358 359
206 360 362
207 363 365
208 366 366
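Note: each time_map.dict row reads "timestep start end", binding a timestep id to an inclusive range of raw time units (day indices here; calendar years for wikidata12k and yago11k below). A hedged lookup sketch, assuming whitespace-separated columns and sorted, non-overlapping ranges as shown (the helper names are illustrative, not part of this repository):

import bisect

# Illustrative sketch: map a raw time value to its timestep bucket.
def load_time_map(path):
    starts, steps = [], []
    with open(path) as f:
        for line in f:
            step, start, end = line.split()
            starts.append(int(start))
            steps.append(int(step))
    return starts, steps

starts, steps = load_time_map('./data/icews14/time_map.dict')

def to_timestep(t):
    # ranges are sorted and contiguous, so the bucket is the last start <= t
    return steps[bisect.bisect_right(starts, t) - 1]

assert to_timestep(4) == 1  # day 4 falls inside "1 3 5"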
data/icews14/train.txt (new file, 72826 lines)
File diff suppressed because it is too large
data/icews14/valid.txt (new file, 8239 lines)
File diff suppressed because it is too large
data/wikidata12k/about.txt (new file, 15 lines)
@@ -0,0 +1,15 @@
# triples: 291818
# entities: 12554
# relations: 423
# timesteps: 70
# test triples: 19271
# valid triples: 20208
# train triples: 252339
Measure method: N/A
Target Size : 423
Grow Factor: 0
Shrink Factor: 4.0
Epsilon Factor: 0
Search method: N/A
filter_dupes: inter
nonames: False
data/wikidata12k/entities.dict (new file, 12554 lines)
File diff suppressed because it is too large
data/wikidata12k/relations.dict (new file, 423 lines)
@@ -0,0 +1,423 @@
0 P131[0-0]
1 P131[1-1]
2 P131[2-2]
3 P131[3-3]
4 P131[4-4]
5 P131[5-5]
6 P131[6-6]
7 P131[7-7]
8 P131[8-8]
9 P131[9-9]
10 P131[10-10]
11 P131[11-11]
12 P131[12-12]
13 P131[13-13]
14 P131[14-14]
15 P131[15-15]
16 P131[16-16]
17 P131[17-17]
18 P131[18-18]
19 P131[19-19]
20 P131[20-20]
21 P131[21-21]
22 P131[22-22]
23 P131[23-23]
24 P131[24-24]
25 P131[25-25]
26 P131[26-26]
27 P131[27-27]
28 P131[28-28]
29 P131[29-29]
30 P131[30-30]
31 P131[31-31]
32 P131[32-32]
33 P131[33-33]
34 P131[34-34]
35 P131[35-35]
36 P131[36-36]
37 P131[37-37]
38 P131[38-38]
39 P131[39-39]
40 P131[40-40]
41 P131[41-41]
42 P131[42-42]
43 P131[43-43]
44 P131[44-44]
45 P131[45-45]
46 P131[46-46]
47 P131[47-47]
48 P131[48-48]
49 P131[49-49]
50 P131[50-50]
51 P131[51-51]
52 P131[52-52]
53 P131[53-53]
54 P131[54-54]
55 P131[55-55]
56 P131[56-56]
57 P131[57-57]
58 P131[58-58]
59 P131[59-59]
60 P131[60-60]
61 P131[61-61]
62 P131[62-62]
63 P131[63-63]
64 P131[64-64]
65 P131[65-65]
66 P131[66-66]
67 P131[67-67]
68 P131[68-68]
69 P131[69-69]
70 P1435[65-65]
71 P39[49-49]
72 P39[50-50]
73 P39[51-51]
74 P39[52-52]
75 P39[53-53]
76 P39[54-54]
77 P39[55-55]
78 P39[56-56]
79 P39[57-57]
80 P39[58-58]
81 P39[59-59]
82 P39[60-60]
83 P39[61-61]
84 P39[62-62]
85 P39[63-63]
86 P39[64-64]
87 P39[65-65]
88 P39[66-66]
89 P39[67-67]
90 P39[68-68]
91 P39[69-69]
92 P54[40-40]
93 P54[41-41]
94 P54[42-42]
95 P54[43-43]
96 P54[44-44]
97 P54[45-45]
98 P54[46-46]
99 P54[47-47]
100 P54[48-48]
101 P54[49-49]
102 P54[50-50]
103 P54[51-51]
104 P54[52-52]
105 P54[53-53]
106 P54[54-54]
107 P54[55-55]
108 P54[56-56]
109 P54[57-57]
110 P54[58-58]
111 P54[59-59]
112 P54[60-60]
113 P54[61-61]
114 P54[62-62]
115 P54[63-63]
116 P54[64-64]
117 P54[65-65]
118 P54[66-66]
119 P54[67-67]
120 P54[68-68]
121 P54[69-69]
122 P31[0-0]
123 P31[1-1]
124 P31[2-2]
125 P31[3-3]
126 P31[4-4]
127 P31[5-5]
128 P31[6-6]
129 P31[7-7]
130 P31[8-8]
131 P31[9-9]
132 P31[10-10]
133 P31[11-11]
134 P31[12-12]
135 P31[13-13]
136 P31[14-14]
137 P31[15-15]
138 P31[16-16]
139 P31[17-17]
140 P31[18-18]
141 P31[19-19]
142 P31[20-20]
143 P31[21-21]
144 P31[22-22]
145 P31[23-23]
146 P31[24-24]
147 P31[25-25]
148 P31[26-26]
149 P31[27-27]
150 P31[28-28]
151 P31[29-29]
152 P31[30-30]
153 P31[31-31]
154 P31[32-32]
155 P31[33-33]
156 P31[34-34]
157 P31[35-35]
158 P31[36-36]
159 P31[37-37]
160 P31[38-38]
161 P31[39-39]
162 P31[40-40]
163 P31[41-41]
164 P31[42-42]
165 P31[43-43]
166 P31[44-44]
167 P31[45-45]
168 P31[46-46]
169 P31[47-47]
170 P31[48-48]
171 P31[49-49]
172 P31[50-50]
173 P31[51-51]
174 P31[52-52]
175 P31[53-53]
176 P31[54-54]
177 P31[55-55]
178 P31[56-56]
179 P31[57-57]
180 P31[58-58]
181 P31[59-59]
182 P31[60-60]
183 P31[61-61]
184 P31[62-62]
185 P31[63-63]
186 P31[64-64]
187 P31[65-65]
188 P31[66-66]
189 P31[67-67]
190 P31[68-68]
191 P31[69-69]
192 P463[26-26]
193 P463[27-27]
194 P463[28-28]
195 P463[29-29]
196 P463[30-30]
197 P463[31-31]
198 P463[32-32]
199 P463[33-33]
200 P463[34-34]
201 P463[35-35]
202 P463[36-36]
203 P463[37-37]
204 P463[38-38]
205 P463[39-39]
206 P463[40-40]
207 P463[41-41]
208 P463[42-42]
209 P463[43-43]
210 P463[44-44]
211 P463[45-45]
212 P463[46-46]
213 P463[47-47]
214 P463[48-48]
215 P463[49-49]
216 P463[50-50]
217 P463[51-51]
218 P463[52-52]
219 P463[53-53]
220 P463[54-54]
221 P463[55-55]
222 P463[56-56]
223 P463[57-57]
224 P463[58-58]
225 P463[59-59]
226 P463[60-60]
227 P463[61-61]
228 P463[62-62]
229 P463[63-63]
230 P463[64-64]
231 P463[65-65]
232 P463[66-66]
233 P463[67-67]
234 P463[68-68]
235 P463[69-69]
236 P512[4-69]
237 P190[0-29]
238 P150[0-3]
239 P1376[39-47]
240 P463[0-7]
241 P166[0-7]
242 P2962[18-30]
243 P108[29-36]
244 P39[0-3]
245 P17[47-48]
246 P166[21-23]
247 P793[46-69]
248 P69[32-41]
249 P17[57-58]
250 P190[42-45]
251 P2962[39-42]
252 P54[0-18]
253 P26[56-61]
254 P150[14-17]
255 P463[16-17]
256 P26[39-46]
257 P579[36-43]
258 P579[16-23]
259 P2962[59-60]
260 P1411[59-61]
261 P26[20-27]
262 P6[4-69]
263 P1435[33-34]
264 P166[52-53]
265 P108[49-57]
266 P150[10-13]
267 P1346[47-68]
268 P150[18-21]
269 P1346[13-46]
270 P69[20-23]
271 P39[31-32]
272 P1411[32-37]
273 P166[62-63]
274 P150[44-47]
275 P2962[61-62]
276 P150[48-51]
277 P150[52-55]
278 P1411[62-67]
279 P1435[35-36]
280 P1411[48-51]
281 P150[22-25]
282 P2962[63-64]
283 P2962[65-66]
284 P166[58-59]
285 P190[46-49]
286 P54[34-35]
287 P1435[4-16]
288 P463[18-19]
289 P150[31-34]
290 P150[35-38]
291 P39[35-36]
292 P26[62-69]
293 P1411[56-58]
294 P1435[37-38]
295 P166[60-61]
296 P39[33-34]
297 P102[24-31]
298 P2962[43-46]
299 P108[37-48]
300 P190[50-53]
301 P39[4-6]
302 P1435[39-40]
303 P793[0-45]
304 P150[64-69]
305 P39[19-22]
306 P27[30-38]
307 P2962[31-38]
308 P1411[24-31]
309 P102[40-45]
310 P39[37-38]
311 P463[8-11]
312 P1435[41-42]
313 P27[52-59]
314 P69[16-19]
315 P17[16-18]
316 P190[54-57]
317 P1435[43-44]
318 P166[8-15]
319 P166[45-47]
320 P2962[47-50]
321 P39[39-40]
322 P1411[52-55]
323 P108[58-69]
324 P463[20-21]
325 P39[41-42]
326 P150[26-30]
327 P150[39-43]
328 P1435[45-46]
329 P26[28-38]
330 P54[27-30]
331 P190[58-61]
332 P17[59-61]
333 P54[36-37]
334 P166[16-20]
335 P166[37-40]
336 P1435[47-48]
337 P17[0-3]
338 P26[47-55]
339 P1435[49-50]
340 P1435[25-28]
341 P150[4-9]
342 P102[63-69]
343 P26[0-19]
344 P1435[17-24]
345 P39[23-26]
346 P1435[51-52]
347 P39[7-11]
348 P69[12-15]
349 P69[24-31]
350 P102[0-23]
351 P39[43-44]
352 P579[24-35]
353 P190[62-65]
354 P1435[53-54]
355 P1376[0-18]
356 P27[0-14]
357 P463[12-15]
358 P166[33-36]
359 P102[32-39]
360 P17[4-7]
361 P190[30-41]
362 P166[24-28]
363 P190[66-69]
364 P69[42-69]
365 P1435[55-56]
366 P54[31-33]
367 P39[45-46]
368 P17[12-15]
369 P1435[57-58]
370 P54[19-26]
371 P2962[51-54]
372 P2962[67-69]
373 P1435[59-60]
374 P579[44-56]
375 P1435[61-62]
376 P166[41-44]
377 P17[19-22]
378 P1376[19-38]
379 P17[23-26]
380 P1376[48-69]
381 P463[22-23]
382 P17[27-30]
383 P1435[63-64]
384 P69[0-3]
385 P1435[66-67]
386 P17[35-38]
387 P69[8-11]
388 P1435[68-69]
389 P17[31-34]
390 P102[46-53]
391 P27[60-69]
392 P579[57-69]
393 P69[4-7]
394 P1411[7-14]
395 P551[0-35]
396 P108[0-28]
397 P17[8-11]
398 P1411[38-47]
399 P17[43-46]
400 P17[49-52]
401 P166[64-69]
402 P1435[29-32]
403 P54[38-39]
404 P39[27-30]
405 P2962[55-58]
406 P463[24-25]
407 P17[39-42]
408 P17[53-56]
409 P17[66-69]
410 P17[62-65]
411 P1411[15-23]
412 P166[48-51]
413 P27[15-29]
414 P150[56-63]
415 P27[39-51]
416 P39[47-48]
417 P166[29-32]
418 P39[12-18]
419 P166[54-57]
420 P551[36-69]
421 P579[0-15]
422 P102[54-62]
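Note: every relation id above binds a Wikidata property to an inclusive timestep interval, e.g. "240 P463[0-7]" is P463 restricted to timesteps 0 through 7. A minimal parsing sketch, assuming all labels follow the name[start-end] pattern shown (the regex and function name are illustrative, not part of this repository):

import re

# Illustrative sketch: split a temporal relation label into (name, start, end).
# The same pattern also covers the bracketed YAGO names further below.
LABEL = re.compile(r'^(.*)\[(-?\d+)-(-?\d+)\]$')

def parse_label(label):
    name, start, end = LABEL.match(label).groups()
    return name, int(start), int(end)

assert parse_label('P463[0-7]') == ('P463', 0, 7)
assert parse_label('<wasBornIn>[0-2]') == ('<wasBornIn>', 0, 2)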
data/wikidata12k/test.txt (new file, 19271 lines)
File diff suppressed because it is too large
data/wikidata12k/time_map.dict (new file, 71 lines)
@@ -0,0 +1,71 @@
0 19 19
1 20 1643
2 1644 1790
3 1791 1816
4 1817 1855
5 1856 1871
6 1872 1893
7 1894 1905
8 1906 1913
9 1914 1918
10 1919 1920
11 1921 1924
12 1925 1929
13 1930 1933
14 1934 1937
15 1938 1941
16 1942 1945
17 1946 1948
18 1949 1950
19 1951 1953
20 1954 1956
21 1957 1959
22 1960 1961
23 1962 1963
24 1964 1965
25 1966 1967
26 1968 1968
27 1969 1970
28 1971 1972
29 1973 1974
30 1975 1976
31 1977 1978
32 1979 1980
33 1981 1982
34 1983 1983
35 1984 1984
36 1985 1985
37 1986 1986
38 1987 1987
39 1988 1988
40 1989 1989
41 1990 1990
42 1991 1991
43 1992 1992
44 1993 1993
45 1994 1994
46 1995 1995
47 1996 1996
48 1997 1997
49 1998 1998
50 1999 1999
51 2000 2000
52 2001 2001
53 2002 2002
54 2003 2003
55 2004 2004
56 2005 2005
57 2006 2006
58 2007 2007
59 2008 2008
60 2009 2009
61 2010 2010
62 2011 2011
63 2012 2012
64 2013 2013
65 2014 2014
66 2015 2015
67 2016 2016
68 2017 2017
69 2018 2020
70 2021 2021
data/wikidata12k/train.txt (new file, 252339 lines)
File diff suppressed because it is too large
data/wikidata12k/valid.txt (new file, 20208 lines)
File diff suppressed because it is too large
data/yago11k/about.txt (new file, 15 lines)
@@ -0,0 +1,15 @@
# triples: 78032
# entities: 10526
# relations: 177
# timesteps: 46
# test triples: 6909
# valid triples: 7198
# train triples: 63925
Measure method: N/A
Target Size : 0
Grow Factor: 0
Shrink Factor: 0
Epsilon Factor: 5.0
Search method: N/A
filter_dupes: inter
nonames: False
data/yago11k/entities.dict (new file, 10526 lines)
File diff suppressed because it is too large
data/yago11k/relations.dict (new file, 177 lines)
@@ -0,0 +1,177 @@
0 <wasBornIn>[0-2]
1 <wasBornIn>[2-5]
2 <wasBornIn>[5-7]
3 <wasBornIn>[7-10]
4 <wasBornIn>[10-12]
5 <wasBornIn>[12-15]
6 <wasBornIn>[15-17]
7 <wasBornIn>[17-20]
8 <wasBornIn>[20-22]
9 <wasBornIn>[22-25]
10 <wasBornIn>[25-27]
11 <wasBornIn>[27-30]
12 <wasBornIn>[30-32]
13 <wasBornIn>[32-35]
14 <wasBornIn>[35-45]
15 <wasBornIn>[52-52]
16 <diedIn>[0-3]
17 <diedIn>[3-5]
18 <diedIn>[5-7]
19 <diedIn>[7-10]
20 <diedIn>[10-12]
21 <diedIn>[12-14]
22 <diedIn>[14-17]
23 <diedIn>[17-19]
24 <diedIn>[19-21]
25 <diedIn>[21-23]
26 <diedIn>[23-25]
27 <diedIn>[25-27]
28 <diedIn>[27-29]
29 <diedIn>[29-32]
30 <diedIn>[32-34]
31 <diedIn>[34-36]
32 <diedIn>[36-38]
33 <diedIn>[38-40]
34 <diedIn>[40-42]
35 <diedIn>[42-44]
36 <diedIn>[44-47]
37 <diedIn>[47-49]
38 <diedIn>[49-51]
39 <diedIn>[51-53]
40 <diedIn>[53-55]
41 <diedIn>[55-57]
42 <diedIn>[59-59]
43 <worksAt>[0-3]
44 <worksAt>[3-5]
45 <worksAt>[5-7]
46 <worksAt>[7-10]
47 <worksAt>[10-12]
48 <worksAt>[12-14]
49 <worksAt>[14-17]
50 <worksAt>[17-19]
51 <worksAt>[19-21]
52 <worksAt>[21-23]
53 <worksAt>[23-25]
54 <worksAt>[25-27]
55 <worksAt>[27-29]
56 <worksAt>[29-32]
57 <worksAt>[32-34]
58 <worksAt>[34-36]
59 <worksAt>[36-40]
60 <worksAt>[40-42]
61 <worksAt>[42-47]
62 <worksAt>[47-53]
63 <worksAt>[59-59]
64 <playsFor>[0-3]
65 <playsFor>[3-5]
66 <playsFor>[5-23]
67 <playsFor>[23-25]
68 <playsFor>[25-27]
69 <playsFor>[27-29]
70 <playsFor>[29-32]
71 <playsFor>[32-34]
72 <playsFor>[34-36]
73 <playsFor>[36-38]
74 <playsFor>[38-40]
75 <playsFor>[40-42]
76 <playsFor>[42-44]
77 <playsFor>[44-47]
78 <playsFor>[47-51]
79 <playsFor>[59-59]
80 <hasWonPrize>[1-4]
81 <hasWonPrize>[4-6]
82 <hasWonPrize>[6-8]
83 <hasWonPrize>[8-11]
84 <hasWonPrize>[11-15]
85 <hasWonPrize>[15-18]
86 <hasWonPrize>[18-22]
87 <hasWonPrize>[22-26]
88 <hasWonPrize>[26-30]
89 <hasWonPrize>[30-33]
90 <hasWonPrize>[33-37]
91 <hasWonPrize>[37-47]
92 <hasWonPrize>[47-53]
93 <hasWonPrize>[59-59]
94 <isMarriedTo>[0-3]
95 <isMarriedTo>[3-5]
96 <isMarriedTo>[5-7]
97 <isMarriedTo>[7-10]
98 <isMarriedTo>[10-12]
99 <isMarriedTo>[12-14]
100 <isMarriedTo>[14-17]
101 <isMarriedTo>[17-19]
102 <isMarriedTo>[19-21]
103 <isMarriedTo>[21-23]
104 <isMarriedTo>[23-25]
105 <isMarriedTo>[25-27]
106 <isMarriedTo>[27-29]
107 <isMarriedTo>[29-32]
108 <isMarriedTo>[32-34]
109 <isMarriedTo>[34-38]
110 <isMarriedTo>[38-42]
111 <isMarriedTo>[42-47]
112 <isMarriedTo>[47-51]
113 <isMarriedTo>[51-55]
114 <isMarriedTo>[59-59]
115 <owns>[0-10]
116 <owns>[10-17]
117 <owns>[17-19]
118 <owns>[19-23]
119 <owns>[23-36]
120 <owns>[36-38]
121 <owns>[59-59]
122 <graduatedFrom>[0-3]
123 <graduatedFrom>[3-5]
124 <graduatedFrom>[5-7]
125 <graduatedFrom>[7-10]
126 <graduatedFrom>[10-14]
127 <graduatedFrom>[14-17]
128 <graduatedFrom>[17-19]
129 <graduatedFrom>[19-21]
130 <graduatedFrom>[21-23]
131 <graduatedFrom>[23-27]
132 <graduatedFrom>[27-32]
133 <graduatedFrom>[32-34]
134 <graduatedFrom>[34-38]
135 <graduatedFrom>[38-42]
136 <graduatedFrom>[59-59]
137 <isAffiliatedTo>[1-4]
138 <isAffiliatedTo>[4-6]
139 <isAffiliatedTo>[6-8]
140 <isAffiliatedTo>[8-11]
141 <isAffiliatedTo>[11-13]
142 <isAffiliatedTo>[13-15]
143 <isAffiliatedTo>[15-18]
144 <isAffiliatedTo>[18-20]
145 <isAffiliatedTo>[20-22]
146 <isAffiliatedTo>[22-24]
147 <isAffiliatedTo>[24-26]
148 <isAffiliatedTo>[26-28]
149 <isAffiliatedTo>[28-30]
150 <isAffiliatedTo>[30-33]
151 <isAffiliatedTo>[33-35]
152 <isAffiliatedTo>[35-37]
153 <isAffiliatedTo>[37-40]
154 <isAffiliatedTo>[40-42]
155 <isAffiliatedTo>[42-44]
156 <isAffiliatedTo>[44-47]
157 <isAffiliatedTo>[47-49]
158 <isAffiliatedTo>[49-51]
159 <isAffiliatedTo>[51-53]
160 <isAffiliatedTo>[53-55]
161 <isAffiliatedTo>[55-57]
162 <isAffiliatedTo>[59-59]
163 <created>[0-3]
164 <created>[3-5]
165 <created>[5-10]
166 <created>[10-12]
167 <created>[12-17]
168 <created>[17-19]
169 <created>[19-25]
170 <created>[25-29]
171 <created>[29-32]
172 <created>[32-36]
173 <created>[36-42]
174 <created>[42-47]
175 <created>[47-53]
176 <created>[59-59]
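Note: joining relations.dict with the time_map.dict below recovers human-readable spans: "0 <wasBornIn>[0-2]" covers timesteps 0 through 2, which map to the years -431 through 1870. A hedged sketch reusing the illustrative parse_label helper from the wikidata12k section above:

# Illustrative sketch: recover the raw year span of a bucketed relation.
spans = {}  # timestep id -> (start_year, end_year)
with open('./data/yago11k/time_map.dict') as f:
    for line in f:
        step, start, end = line.split()
        spans[int(step)] = (int(start), int(end))

name, lo, hi = parse_label('<wasBornIn>[0-2]')
assert (spans[lo][0], spans[hi][1]) == (-431, 1870)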
data/yago11k/test.txt (new file, 6909 lines)
File diff suppressed because it is too large
data/yago11k/time_map.dict (new file, 60 lines)
@@ -0,0 +1,60 @@
0 -431 1782
1 1783 1848
2 1849 1870
3 1871 1888
4 1889 1899
5 1900 1906
6 1907 1912
7 1913 1917
8 1918 1922
9 1923 1926
10 1927 1930
11 1931 1934
12 1935 1938
13 1939 1941
14 1942 1944
15 1945 1947
16 1948 1950
17 1951 1953
18 1954 1956
19 1957 1959
20 1960 1962
21 1963 1965
22 1966 1967
23 1968 1969
24 1970 1971
25 1972 1973
26 1974 1975
27 1976 1977
28 1978 1979
29 1980 1981
30 1982 1983
31 1984 1985
32 1986 1987
33 1988 1989
34 1990 1991
35 1992 1993
36 1994 1994
37 1995 1996
38 1997 1997
39 1998 1998
40 1999 1999
41 2000 2000
42 2001 2001
43 2002 2002
44 2003 2003
45 2004 2004
46 2005 2005
47 2006 2006
48 2007 2007
49 2008 2008
50 2009 2009
51 2010 2010
52 2011 2011
53 2012 2012
54 2013 2013
55 2014 2014
56 2015 2015
57 2016 2016
58 2017 2017
59 2018 2018
data/yago11k/train.txt (new file, 63925 lines)
File diff suppressed because it is too large
data/yago11k/valid.txt (new file, 7198 lines)
File diff suppressed because it is too large
main.py (20 lines changed)
@@ -80,9 +80,18 @@ class Main(object):
             ent_set.add(sub)
             rel_set.add(rel)
             ent_set.add(obj)
 
+        self.ent2id = {}
+        for line in open('./data/{}/{}'.format(self.p.dataset, "entities.dict")):
+            id, ent = map(str.lower, line.strip().split('\t'))
+            self.ent2id[ent] = int(id)
+        self.rel2id = {}
+        for line in open('./data/{}/{}'.format(self.p.dataset, "relations.dict")):
+            id, rel = map(str.lower, line.strip().split('\t'))
+            self.rel2id[rel] = int(id)
+
-        self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
-        self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
+        # self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
+        # self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
         self.rel2id.update({rel+'_reverse': idx+len(self.rel2id)
                             for idx, rel in enumerate(rel_set)})
 
@@ -569,9 +578,9 @@ if __name__ == "__main__":
                         help='Dropout for Feature. Default: 0.5. Test: 0.2, 0.3, 0.4, 0.5')
     parser.add_argument('--inp_drop', dest="inp_drop", default=0.2, type=float,
                         help='Dropout for Input layer. Default: 0.5. Test: 0.2, 0.3, 0.4, 0.5')
-    parser.add_argument('--drop_path', dest="drop_path", default=0.1, type=float,
+    parser.add_argument('--drop_path', dest="drop_path", default=0.0, type=float,
                         help='Path dropout. Default: 0.5. Test: 0.2, 0.3, 0.4, 0.5')
-    parser.add_argument('--drop', dest="drop", default=0.2, type=float,
+    parser.add_argument('--drop', dest="drop", default=0.0, type=float,
                         help='Inner drop. Default: 0.5. Test: 0.2, 0.3, 0.4, 0.5')
 
     # Configuration for in/output channels for ConvE, HypER, HypE
@@ -616,6 +625,7 @@ if __name__ == "__main__":
                         default='./config/', help='Config directory')
 
     parser.add_argument('--test_only', action='store_true', default=False)
+    parser.add_argument('--filtered', action='store_true', default=False)
 
     args = parser.parse_args()
 
@@ -630,4 +640,4 @@
         model.load_model(save_path)
         model.evaluate('test')
     else:
-        model.fit()
+        model.fit()
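Note on the change above: entity and relation ids are now read from the shipped entities.dict / relations.dict files instead of being re-derived by enumerating an unordered Python set, so ids stay stable across runs and agree with the temporally bucketed relation labels in those files. A hedged restatement of the new loader (same behaviour, but it closes the file and avoids shadowing the id builtin; illustrative, not the committed code):

# Illustrative restatement of the loader added in the hunk above.
def load_id_map(path):
    mapping = {}
    with open(path) as f:  # the committed version leaves the handle to the GC
        for line in f:
            idx, name = line.strip().split('\t')
            mapping[name.lower()] = int(idx)  # committed code lowercases both columns
    return mapping

# e.g. self.ent2id = load_id_map('./data/{}/entities.dict'.format(self.p.dataset))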
pvt.py (deleted, 401 lines)
@@ -1,401 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import math


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.linear = linear
        if self.linear:
            self.relu = nn.ReLU(inplace=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        if self.linear:
            x = self.relu(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.linear = linear
        self.sr_ratio = sr_ratio
        if not linear:
            if sr_ratio > 1:
                self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
                self.norm = nn.LayerNorm(dim)
        else:
            self.pool = nn.AdaptiveAvgPool2d(7)
            self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
            self.norm = nn.LayerNorm(dim)
            self.act = nn.GELU()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if not self.linear:
            if self.sr_ratio > 1:
                x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
                x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
                x_ = self.norm(x_)
                kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            else:
                kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            x_ = self.act(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x


class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        assert max(patch_size) > stride, "Set larger patch_size than stride"

        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // stride, img_size[1] // stride
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W


class PyramidVisionTransformerV2(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4, linear=False):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0

        for i in range(num_stages):
            patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
                                            patch_size=7 if i == 0 else 3,
                                            stride=4 if i == 0 else 2,
                                            in_chans=in_chans if i == 0 else embed_dims[i - 1],
                                            embed_dim=embed_dims[i])

            block = nn.ModuleList([Block(
                dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer,
                sr_ratio=sr_ratios[i], linear=linear)
                for j in range(depths[i])])
            norm = norm_layer(embed_dims[i])
            cur += depths[i]

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)

        # classification head
        self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def freeze_patch_emb(self):
        self.patch_embed1.requires_grad = False

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # has pos_embed may be better

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.shape[0]

        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            norm = getattr(self, f"norm{i + 1}")
            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x, H, W)
            x = norm(x)
            if i != self.num_stages - 1:
                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        return x.mean(dim=1)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)

        return x


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x


def _conv_filter(state_dict, patch_size=16):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v

    return out_dict


@register_model
def pvt_v2_b0(pretrained=False, **kwargs):
    model = PyramidVisionTransformerV2(
        patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
        **kwargs)
    model.default_cfg = _cfg()

    return model


@register_model
def pvt_v2_b1(pretrained=False, **kwargs):
    model = PyramidVisionTransformerV2(
        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
        **kwargs)
    model.default_cfg = _cfg()

    return model


@register_model
def pvt_v2_b2(pretrained=False, **kwargs):
    model = PyramidVisionTransformerV2(
        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
    model.default_cfg = _cfg()

    return model


@register_model
def pvt_v2_b3(pretrained=False, **kwargs):
    model = PyramidVisionTransformerV2(
        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
        **kwargs)
    model.default_cfg = _cfg()

    return model


@register_model
def pvt_v2_b4(pretrained=False, **kwargs):
    model = PyramidVisionTransformerV2(
        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
        **kwargs)
    model.default_cfg = _cfg()

    return model


@register_model
def pvt_v2_b5(pretrained=False, **kwargs):
    model = PyramidVisionTransformerV2(
        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
        **kwargs)
    model.default_cfg = _cfg()

    return model


@register_model
def pvt_v2_b2_li(pretrained=False, **kwargs):
    model = PyramidVisionTransformerV2(
        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], linear=True, **kwargs)
    model.default_cfg = _cfg()

    return model