45
45
import static io .netty .handler .codec .http2 .Http2Error .INTERNAL_ERROR ;
46
46
import static io .netty .handler .codec .http2 .Http2Error .PROTOCOL_ERROR ;
47
47
import static io .netty .handler .codec .http2 .Http2Exception .connectionError ;
48
+ import static io .netty .handler .codec .http2 .Http2Headers .PseudoHeaderName .getPseudoHeader ;
49
+ import static io .netty .handler .codec .http2 .Http2Headers .PseudoHeaderName .hasPseudoHeaderFormat ;
48
50
import static io .netty .util .AsciiString .EMPTY_STRING ;
49
51
import static io .netty .util .internal .ObjectUtil .checkPositive ;
50
52
import static io .netty .util .internal .ThrowableUtil .unknownStackTrace ;
@@ -119,14 +121,15 @@ final class HpackDecoder {
119
121
* <p>
120
122
* This method assumes the entire header block is contained in {@code in}.
121
123
*/
122
- public void decode (int streamId , ByteBuf in , Http2Headers headers ) throws Http2Exception {
124
+ public void decode (int streamId , ByteBuf in , Http2Headers headers , boolean validateHeaders ) throws Http2Exception {
123
125
int index = 0 ;
124
126
long headersLength = 0 ;
125
127
int nameLength = 0 ;
126
128
int valueLength = 0 ;
127
129
byte state = READ_HEADER_REPRESENTATION ;
128
130
boolean huffmanEncoded = false ;
129
131
CharSequence name = null ;
132
+ HeaderType headerType = null ;
130
133
IndexType indexType = IndexType .NONE ;
131
134
while (in .isReadable ()) {
132
135
switch (state ) {
@@ -146,7 +149,10 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers) throws Http2E
146
149
state = READ_INDEXED_HEADER ;
147
150
break ;
148
151
default :
149
- headersLength = indexHeader (index , headers , headersLength );
152
+ HpackHeaderField indexedHeader = getIndexedHeader (index );
153
+ headerType = validate (indexedHeader .name , headerType , validateHeaders );
154
+ headersLength = addHeader (headers , indexedHeader .name , indexedHeader .value ,
155
+ headersLength );
150
156
}
151
157
} else if ((b & 0x40 ) == 0x40 ) {
152
158
// Literal Header Field with Incremental Indexing
@@ -162,6 +168,7 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers) throws Http2E
162
168
default :
163
169
// Index was stored as the prefix
164
170
name = readName (index );
171
+ headerType = validate (name , headerType , validateHeaders );
165
172
nameLength = name .length ();
166
173
state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX ;
167
174
}
@@ -188,6 +195,7 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers) throws Http2E
188
195
default :
189
196
// Index was stored as the prefix
190
197
name = readName (index );
198
+ headerType = validate (name , headerType , validateHeaders );
191
199
nameLength = name .length ();
192
200
state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX ;
193
201
}
@@ -200,13 +208,16 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers) throws Http2E
200
208
break ;
201
209
202
210
case READ_INDEXED_HEADER :
203
- headersLength = indexHeader (decodeULE128 (in , index ), headers , headersLength );
211
+ HpackHeaderField indexedHeader = getIndexedHeader (decodeULE128 (in , index ));
212
+ headerType = validate (indexedHeader .name , headerType , validateHeaders );
213
+ headersLength = addHeader (headers , indexedHeader .name , indexedHeader .value , headersLength );
204
214
state = READ_HEADER_REPRESENTATION ;
205
215
break ;
206
216
207
217
case READ_INDEXED_HEADER_NAME :
208
218
// Header Name matches an entry in the Header Table
209
219
name = readName (decodeULE128 (in , index ));
220
+ headerType = validate (name , headerType , validateHeaders );
210
221
nameLength = name .length ();
211
222
state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX ;
212
223
break ;
@@ -243,6 +254,7 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers) throws Http2E
243
254
}
244
255
245
256
name = readStringLiteral (in , nameLength , huffmanEncoded );
257
+ headerType = validate (name , headerType , validateHeaders );
246
258
247
259
state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX ;
248
260
break ;
@@ -256,6 +268,7 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers) throws Http2E
256
268
state = READ_LITERAL_HEADER_VALUE_LENGTH ;
257
269
break ;
258
270
case 0 :
271
+ headerType = validate (name , headerType , validateHeaders );
259
272
headersLength = insertHeader (headers , name , EMPTY_STRING , indexType , headersLength );
260
273
state = READ_HEADER_REPRESENTATION ;
261
274
break ;
@@ -288,6 +301,7 @@ public void decode(int streamId, ByteBuf in, Http2Headers headers) throws Http2E
288
301
}
289
302
290
303
CharSequence value = readStringLiteral (in , valueLength , huffmanEncoded );
304
+ headerType = validate (name , headerType , validateHeaders );
291
305
headersLength = insertHeader (headers , name , value , indexType , headersLength );
292
306
state = READ_HEADER_REPRESENTATION ;
293
307
break ;
@@ -386,6 +400,34 @@ private void setDynamicTableSize(long dynamicTableSize) throws Http2Exception {
386
400
hpackDynamicTable .setCapacity (dynamicTableSize );
387
401
}
388
402
403
+ private HeaderType validate (CharSequence name , HeaderType previousHeaderType ,
404
+ final boolean validateHeaders ) throws Http2Exception {
405
+ if (!validateHeaders ) {
406
+ return null ;
407
+ }
408
+
409
+ if (hasPseudoHeaderFormat (name )) {
410
+ if (previousHeaderType == HeaderType .REGULAR_HEADER ) {
411
+ throw connectionError (PROTOCOL_ERROR , "Pseudo-header field '%s' found after regular header." , name );
412
+ }
413
+
414
+ final Http2Headers .PseudoHeaderName pseudoHeader = getPseudoHeader (name );
415
+ if (pseudoHeader == null ) {
416
+ throw connectionError (PROTOCOL_ERROR , "Invalid HTTP/2 pseudo-header '%s' encountered." , name );
417
+ }
418
+
419
+ final HeaderType currentHeaderType = pseudoHeader .isRequestOnly () ?
420
+ HeaderType .REQUEST_PSEUDO_HEADER : HeaderType .RESPONSE_PSEUDO_HEADER ;
421
+ if (previousHeaderType != null && currentHeaderType != previousHeaderType ) {
422
+ throw connectionError (PROTOCOL_ERROR , "Mix of request and response pseudo-headers." );
423
+ }
424
+
425
+ return currentHeaderType ;
426
+ }
427
+
428
+ return HeaderType .REGULAR_HEADER ;
429
+ }
430
+
389
431
private CharSequence readName (int index ) throws Http2Exception {
390
432
if (index <= HpackStaticTable .length ) {
391
433
HpackHeaderField hpackHeaderField = HpackStaticTable .getEntry (index );
@@ -398,14 +440,12 @@ private CharSequence readName(int index) throws Http2Exception {
398
440
throw READ_NAME_ILLEGAL_INDEX_VALUE ;
399
441
}
400
442
401
- private long indexHeader (int index , Http2Headers headers , long headersLength ) throws Http2Exception {
443
+ private HpackHeaderField getIndexedHeader (int index ) throws Http2Exception {
402
444
if (index <= HpackStaticTable .length ) {
403
- HpackHeaderField hpackHeaderField = HpackStaticTable .getEntry (index );
404
- return addHeader (headers , hpackHeaderField .name , hpackHeaderField .value , headersLength );
445
+ return HpackStaticTable .getEntry (index );
405
446
}
406
447
if (index - HpackStaticTable .length <= hpackDynamicTable .length ()) {
407
- HpackHeaderField hpackHeaderField = hpackDynamicTable .getEntry (index - HpackStaticTable .length );
408
- return addHeader (headers , hpackHeaderField .name , hpackHeaderField .value , headersLength );
448
+ return hpackDynamicTable .getEntry (index - HpackStaticTable .length );
409
449
}
410
450
throw INDEX_HEADER_ILLEGAL_INDEX_VALUE ;
411
451
}
@@ -504,4 +544,13 @@ static long decodeULE128(ByteBuf in, long result) throws Http2Exception {
504
544
505
545
throw DECODE_ULE_128_DECOMPRESSION_EXCEPTION ;
506
546
}
547
+
548
+ /**
549
+ * HTTP/2 header types.
550
+ */
551
+ private enum HeaderType {
552
+ REGULAR_HEADER ,
553
+ REQUEST_PSEUDO_HEADER ,
554
+ RESPONSE_PSEUDO_HEADER
555
+ }
507
556
}
0 commit comments