Use new Content-Type header value parser. (#302)

This commit is contained in:
Eugene Kabanov 2022-08-05 19:59:26 +03:00 committed by GitHub
parent 79c51914ae
commit 939195626f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 79 additions and 81 deletions

View File

@ -157,6 +157,7 @@ type
contentEncoding*: set[ContentEncodingFlags] contentEncoding*: set[ContentEncodingFlags]
transferEncoding*: set[TransferEncodingFlags] transferEncoding*: set[TransferEncodingFlags]
contentLength*: uint64 contentLength*: uint64
contentType*: Opt[ContentTypeData]
HttpClientResponseRef* = ref HttpClientResponse HttpClientResponseRef* = ref HttpClientResponse
@ -783,13 +784,25 @@ proc prepareResponse(request: HttpClientRequestRef, data: openArray[byte]
else: else:
false false
let contentType =
block:
let list = headers.getList(ContentTypeHeader)
if len(list) > 0:
let res = getContentType(list)
if res.isErr():
return err("Invalid headers received, invalid `Content-Type`")
else:
Opt.some(res.get())
else:
Opt.none(ContentTypeData)
let res = HttpClientResponseRef( let res = HttpClientResponseRef(
state: HttpReqRespState.Open, status: resp.code, state: HttpReqRespState.Open, status: resp.code,
address: request.address, requestMethod: request.meth, address: request.address, requestMethod: request.meth,
reason: resp.reason(data), version: resp.version, session: request.session, reason: resp.reason(data), version: resp.version, session: request.session,
connection: request.connection, headers: headers, connection: request.connection, headers: headers,
contentEncoding: contentEncoding, transferEncoding: transferEncoding, contentEncoding: contentEncoding, transferEncoding: transferEncoding,
contentLength: contentLength, bodyFlag: bodyFlag contentLength: contentLength, contentType: contentType, bodyFlag: bodyFlag
) )
res.connection.state = HttpClientConnectionState.ResponseHeadersReceived res.connection.state = HttpClientConnectionState.ResponseHeadersReceived
if nobodyFlag: if nobodyFlag:

View File

@ -32,8 +32,8 @@ const
LocationHeader* = "location" LocationHeader* = "location"
AuthorizationHeader* = "authorization" AuthorizationHeader* = "authorization"
UrlEncodedContentType* = "application/x-www-form-urlencoded" UrlEncodedContentType* = MediaType.init("application/x-www-form-urlencoded")
MultipartContentType* = "multipart/form-data" MultipartContentType* = MediaType.init("multipart/form-data")
type type
HttpResult*[T] = Result[T, string] HttpResult*[T] = Result[T, string]
@ -193,7 +193,7 @@ func getContentEncoding*(ch: openArray[string]): HttpResult[
return err("Incorrect Content-Encoding value") return err("Incorrect Content-Encoding value")
ok(res) ok(res)
func getContentType*(ch: openArray[string]): HttpResult[string] {. func getContentType*(ch: openArray[string]): HttpResult[ContentTypeData] {.
raises: [Defect].} = raises: [Defect].} =
## Check and prepare value of ``Content-Type`` header. ## Check and prepare value of ``Content-Type`` header.
if len(ch) == 0: if len(ch) == 0:
@ -201,8 +201,10 @@ func getContentType*(ch: openArray[string]): HttpResult[string] {.
elif len(ch) > 1: elif len(ch) > 1:
err("Multiple Content-Type values found") err("Multiple Content-Type values found")
else: else:
let mparts = ch[0].split(";") let res = getContentType(ch[0])
ok(strip(mparts[0]).toLowerAscii()) if res.isErr():
return err($res.error())
ok(res.get())
proc bytesToString*(src: openArray[byte], dst: var openArray[char]) = proc bytesToString*(src: openArray[byte], dst: var openArray[char]) =
## Convert array of bytes to array of characters. ## Convert array of bytes to array of characters.

View File

@ -98,6 +98,7 @@ type
transferEncoding*: set[TransferEncodingFlags] transferEncoding*: set[TransferEncodingFlags]
requestFlags*: set[HttpRequestFlags] requestFlags*: set[HttpRequestFlags]
contentLength: int contentLength: int
contentTypeData*: Option[ContentTypeData]
connection*: HttpConnectionRef connection*: HttpConnectionRef
response*: Option[HttpResponseRef] response*: Option[HttpResponseRef]
@ -324,9 +325,10 @@ proc prepareRequest(conn: HttpConnectionRef,
# steps to reveal information about body. # steps to reveal information about body.
if ContentLengthHeader in request.headers: if ContentLengthHeader in request.headers:
let length = request.headers.getInt(ContentLengthHeader) let length = request.headers.getInt(ContentLengthHeader)
if length > 0: if length >= 0:
if request.meth == MethodTrace: if request.meth == MethodTrace:
return err(Http400) return err(Http400)
# Because of coversion to `int` we should avoid unexpected OverflowError.
if length > uint64(high(int)): if length > uint64(high(int)):
return err(Http413) return err(Http413)
if length > uint64(conn.server.maxRequestBodySize): if length > uint64(conn.server.maxRequestBodySize):
@ -342,12 +344,14 @@ proc prepareRequest(conn: HttpConnectionRef,
if request.hasBody(): if request.hasBody():
# If request has body, we going to understand how its encoded. # If request has body, we going to understand how its encoded.
if ContentTypeHeader in request.headers: if ContentTypeHeader in request.headers:
let contentType = request.headers.getString(ContentTypeHeader) let contentType =
let tmp = strip(contentType).toLowerAscii() getContentType(request.headers.getList(ContentTypeHeader)).valueOr:
if tmp.startsWith(UrlEncodedContentType): return err(Http415)
if contentType == UrlEncodedContentType:
request.requestFlags.incl(HttpRequestFlags.UrlencodedForm) request.requestFlags.incl(HttpRequestFlags.UrlencodedForm)
elif tmp.startsWith(MultipartContentType): elif contentType == MultipartContentType:
request.requestFlags.incl(HttpRequestFlags.MultipartForm) request.requestFlags.incl(HttpRequestFlags.MultipartForm)
request.contentTypeData = some(contentType)
if ExpectHeader in request.headers: if ExpectHeader in request.headers:
let expectHeader = request.headers.getString(ExpectHeader) let expectHeader = request.headers.getString(ExpectHeader)
@ -899,19 +903,17 @@ proc join*(server: HttpServerRef): Future[void] =
retFuture retFuture
proc getMultipartReader*(req: HttpRequestRef): HttpResult[MultiPartReaderRef] = proc getMultipartReader*(req: HttpRequestRef): HttpResult[MultiPartReaderRef] {.
raises: [Defect].} =
## Create new MultiPartReader interface for specific request. ## Create new MultiPartReader interface for specific request.
if req.meth in PostMethods: if req.meth in PostMethods:
if MultipartForm in req.requestFlags: if MultipartForm in req.requestFlags:
let ctype = ? getContentType(req.headers.getList(ContentTypeHeader)) if req.contentTypeData.isSome():
if ctype != MultipartContentType: let boundary = ? getMultipartBoundary(req.contentTypeData.get())
err("Content type is not supported")
else:
let boundary = ? getMultipartBoundary(
req.headers.getList(ContentTypeHeader)
)
var stream = ? req.getBodyReader() var stream = ? req.getBodyReader()
ok(MultiPartReaderRef.new(stream, boundary)) ok(MultiPartReaderRef.new(stream, boundary))
else:
err("Content type is missing or invalid")
else: else:
err("Request's data is not multipart encoded") err("Request's data is not multipart encoded")
else: else:

View File

@ -8,11 +8,11 @@
# Apache License, version 2.0, (LICENSE-APACHEv2) # Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT) # MIT license (LICENSE-MIT)
import std/[monotimes, strutils] import std/[monotimes, strutils]
import stew/results import stew/results, httputils
import ../../asyncloop import ../../asyncloop
import ../../streams/[asyncstream, boundstream, chunkstream] import ../../streams/[asyncstream, boundstream, chunkstream]
import httptable, httpcommon, httpbodyrw import httptable, httpcommon, httpbodyrw
export asyncloop, httptable, httpcommon, httpbodyrw, asyncstream export asyncloop, httptable, httpcommon, httpbodyrw, asyncstream, httputils
const const
UnableToReadMultipartBody = "Unable to read multipart message body" UnableToReadMultipartBody = "Unable to read multipart message body"
@ -439,55 +439,25 @@ func validateBoundary[B: BChar](boundary: openArray[B]): HttpResult[void] =
return err("Content-Type boundary alphabet incorrect") return err("Content-Type boundary alphabet incorrect")
ok() ok()
func getMultipartBoundary*(ch: openArray[string]): HttpResult[string] {. func getMultipartBoundary*(contentData: ContentTypeData): HttpResult[string] {.
raises: [Defect].} = raises: [Defect].} =
## Returns ``multipart/form-data`` boundary value from ``Content-Type`` ## Returns ``multipart/form-data`` boundary value from ``Content-Type``
## header. ## header.
## ##
## The procedure carries out all the necessary checks: ## The procedure carries out all the necessary checks:
## 1) There should be single `Content-Type` header value in headers. ## 1) `boundary` value must be present.
## 2) `Content-Type` must be ``multipart/form-data``. ## 2) `boundary` value must be less then 70 characters length and
## 3) `boundary` value must be present
## 4) `boundary` value must be less then 70 characters length and
## all characters should be part of specific alphabet. ## all characters should be part of specific alphabet.
if len(ch) > 1: let candidate =
err("Multiple Content-Type headers found") block:
else: var res: string
if len(ch) == 0: for item in contentData.params:
err("Content-Type header is missing") if cmpIgnoreCase(item.name, "boundary") == 0:
else: res = item.value
if len(ch[0]) == 0: break
return err("Content-Type header has empty value") res
let mparts = ch[0].split(";") ? validateBoundary(candidate)
if strip(mparts[0]).toLowerAscii() != "multipart/form-data": ok(candidate)
return err("Content-Type is not multipart")
if len(mparts) < 2:
return err("Content-Type missing boundary value")
let index =
block:
var idx = 0
for i in 1 ..< len(mparts):
let stripped = strip(mparts[i])
if stripped.toLowerAscii().startsWith("boundary="):
idx = i
break
idx
if index == 0:
err("Missing Content-Type boundary key")
else:
let stripped = strip(mparts[index])
let bparts = stripped.split("=", 1)
if len(bparts) < 2:
err("Missing Content-Type boundary")
else:
let candidate = strip(bparts[1])
let res = validateBoundary(candidate)
if res.isErr():
err($res.error())
else:
ok(candidate)
proc quoteCheck(name: string): HttpResult[string] = proc quoteCheck(name: string): HttpResult[string] =
if len(name) > 0: if len(name) > 0:

View File

@ -610,39 +610,50 @@ suite "HTTP server testing suite":
"--------------------------------------------------; charset=UTF-8", "--------------------------------------------------; charset=UTF-8",
"-----------------------------------------------------------------" & "-----------------------------------------------------------------" &
"-----"), "-----"),
("multipart/form-data; boundary=ABCDEFGHIJKLMNOPQRST" & ("multipart/form-data; boundary=\"ABCDEFGHIJKLMNOPQRST" &
"UVWXYZabcdefghijklmnopqrstuvwxyz0123456789'()+_,-.; charset=UTF-8", "UVWXYZabcdefghijklmnopqrstuvwxyz0123456789'()+_,-.\"; charset=UTF-8",
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'()" & "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'()" &
"+_,-."), "+_,-."),
("multipart/form-data; boundary=ABCDEFGHIJKLMNOPQRST" & ("multipart/form-data; boundary=\"ABCDEFGHIJKLMNOPQRST" &
"UVWXYZabcdefghijklmnopqrstuvwxyz0123456789'()+?=:/; charset=UTF-8", "UVWXYZabcdefghijklmnopqrstuvwxyz0123456789'()+?=:/\"; charset=UTF-8",
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'()" & "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'()" &
"+?=:/"), "+?=:/"),
("multipart/form-data; charset=UTF-8; boundary=ABCDEFGHIJKLMNOPQRST" & ("multipart/form-data; charset=UTF-8; boundary=\"ABCDEFGHIJKLMNOPQRST" &
"UVWXYZabcdefghijklmnopqrstuvwxyz0123456789'()+_,-.", "UVWXYZabcdefghijklmnopqrstuvwxyz0123456789'()+_,-.\"",
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'()" & "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'()" &
"+_,-."), "+_,-."),
("multipart/form-data; charset=UTF-8; boundary=ABCDEFGHIJKLMNOPQRST" & ("multipart/form-data; charset=UTF-8; boundary=\"ABCDEFGHIJKLMNOPQRST" &
"UVWXYZabcdefghijklmnopqrstuvwxyz0123456789'()+?=:/", "UVWXYZabcdefghijklmnopqrstuvwxyz0123456789'()+?=:/\"",
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'()" & "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'()" &
"+?=:/") "+?=:/"),
("multipart/form-data; charset=UTF-8; boundary=0123456789ABCDEFGHIJKL" &
"MNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+-",
"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+-"),
("multipart/form-data; boundary=0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZa" &
"bcdefghijklmnopqrstuvwxyz+-; charset=UTF-8",
"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+-")
] ]
proc performCheck(ch: openArray[string]): HttpResult[string] =
let cdata = ? getContentType(ch)
if cdata.mediaType != MediaType.init("multipart/form-data"):
return err("Invalid media type")
getMultipartBoundary(cdata)
for i in 0 ..< 256: for i in 0 ..< 256:
let boundary = "multipart/form-data; boundary=" & $char(i) let boundary = "multipart/form-data; boundary=\"" & $char(i) & "\""
if char(i) in AllowedCharacters: if char(i) in AllowedCharacters:
check getMultipartBoundary([boundary]).isOk() check performCheck([boundary]).isOk()
else: else:
check getMultipartBoundary([boundary]).isErr() check performCheck([boundary]).isErr()
check: check:
getMultipartBoundary([]).isErr() performCheck([]).isErr()
getMultipartBoundary(["multipart/form-data; boundary=A", performCheck(["multipart/form-data; boundary=A",
"multipart/form-data; boundary=B"]).isErr() "multipart/form-data; boundary=B"]).isErr()
for item in FailureVectors: for item in FailureVectors:
check getMultipartBoundary([item]).isErr() check performCheck([item]).isErr()
for item in SuccessVectors: for item in SuccessVectors:
let res = getMultipartBoundary([item[0]]) let res = performCheck([item[0]])
check: check:
res.isOk() res.isOk()
item[1] == res.get() item[1] == res.get()