Skip to content

Commit

Permalink
static isSecurity
Browse files Browse the repository at this point in the history
  • Loading branch information
bung87 committed May 28, 2023
1 parent d66beae commit 410054f
Showing 1 changed file with 61 additions and 58 deletions.
119 changes: 61 additions & 58 deletions src/scorper/http/streamserver.nim
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,8 @@ proc defaultErrorHandle(req: ImpRequest, err: ref Exception | HttpError; headers
headers.ContentType "text/plain"
await req.respError(Http400, err.msg, headers)

template tryHandle(body: untyped, keep: var bool) =
template tryHandle(req: ImpRequest, body: untyped, keep: var bool) =
mixin defaultErrorHandle
try:
await wait(body, TimeOut.seconds)
except AsyncTimeoutError:
Expand All @@ -673,6 +674,7 @@ template tryHandle(body: untyped, keep: var bool) =
proc processRequest(
scorper: Scorper,
req: ImpRequest,
isSecurity: static[bool]
): Future[bool] {.async.} =
req.responded = false
req.parsed = false
Expand All @@ -694,12 +696,12 @@ proc processRequest(
await req.respStatus(Http400, e.msg)
return false
# Headers
let headerEnd = req.server.httpParser.parseHeader(addr req.buf[0], req.buf.len)
let headerEnd = scorper.httpParser.parseHeader(addr req.buf[0], req.buf.len)
if headerEnd == -1:
await req.respError(Http400)
return true
req.server.httpParser.toHttpHeaders(req.headers)
case req.server.httpParser.getMethod
scorper.httpParser.toHttpHeaders(req.headers)
case scorper.httpParser.getMethod
of "GET": req.meth = HttpGet
of "POST": req.meth = HttpPost
of "HEAD": req.meth = HttpHead
Expand All @@ -713,21 +715,21 @@ proc processRequest(
await req.respError(Http501)
return true

req.path = req.server.httpParser.getPath()
req.path = scorper.httpParser.getPath()
try:
req.url = parseUrl("http://" & (if req.server.isSecurity: "s" else: "") & req.hostname & req.path)[]
req.url = parseUrl("http://" & (when isSecurity: "s" else: "") & req.hostname & req.path)[]
except ValueError as e:
req.server.logSub.next(e.msg)
scorper.logSub.next(e.msg)
asyncSpawn req.respError(Http400)
return true
case req.server.httpParser.major[]:
case scorper.httpParser.major[]:
of '1':
req.protocol.major = 1
of '2':
req.protocol.major = 2
else:
discard
case req.server.httpParser.minor[]:
case scorper.httpParser.minor[]:
of '0':
req.protocol.minor = 0
of '1':
Expand Down Expand Up @@ -772,7 +774,7 @@ proc processRequest(
else:
shallowCopy(req.query, req.url.query)
handlePreProcessMiddlewares(req)
tryHandle(scorper.callback(req), keep)
tryHandle(req, scorper.callback(req), keep)
if not keep:
return false
discard await postCheck(req)
Expand All @@ -786,7 +788,7 @@ proc processRequest(
shallowCopy(req.query, req.url.query)
req.prefix = matched.route.prefix
handlePreProcessMiddlewares(req)
tryHandle(matched.handler(req), keep)
tryHandle(req, matched.handler(req), keep)
if not keep:
return false
discard await postCheck(req)
Expand All @@ -812,51 +814,52 @@ proc processRequest(
else:
return false

template processClientImpl(isSecurity: bool): proc (server: StreamServer, transp: StreamTransport) {.async.} =
proc processClient(server: StreamServer, transp: StreamTransport) {.async.} =
var req = ImpRequest()
req.headers = newHttpHeaders()
when defined(gcArc) or defined(gcOrc):
req.server = cast[Scorper](server)
req.transp = transp
else:
shallowCopy(req.server, cast[Scorper](server))
shallowCopy(req.transp, transp)
try:
req.hostname = $req.transp.localAddress
except TransportError:
discard
try:
req.ip = $req.transp.remoteAddress
except TransportError:
discard

proc processClient(server: StreamServer, transp: StreamTransport) {.async.} =
var req = ImpRequest()
req.headers = newHttpHeaders()
when defined(gcArc) or defined(gcOrc):
req.server = cast[Scorper](server)
req.transp = transp
else:
shallowCopy(req.server, cast[Scorper](server))
shallowCopy(req.transp, transp)
try:
req.hostname = $req.transp.localAddress
except TransportError:
discard
try:
req.ip = $req.transp.remoteAddress
except TransportError:
discard

when defined(ssl):
if req.server.isSecurity:
req.tlsStream =
newTLSServerAsyncStream(req.transp.newAsyncStreamReader, req.transp.newAsyncStreamWriter,
req.server.tlsPrivateKey,
req.server.tlsCertificate,
minVersion = req.server.tlsMinVersion,
maxVersion = req.server.tlsMaxVersion,
flags = req.server.secureFlags)
req.reader = req.tlsStream.reader
req.writer = req.tlsStream.writer
await handshake(req.tlsStream)
when defined(ssl):
when isSecurity:
req.tlsStream =
newTLSServerAsyncStream(req.transp.newAsyncStreamReader, req.transp.newAsyncStreamWriter,
req.server.tlsPrivateKey,
req.server.tlsCertificate,
minVersion = req.server.tlsMinVersion,
maxVersion = req.server.tlsMaxVersion,
flags = req.server.secureFlags)
req.reader = req.tlsStream.reader
req.writer = req.tlsStream.writer
await handshake(req.tlsStream)
else:
req.reader = req.transp.newAsyncStreamReader
req.writer = req.transp.newAsyncStreamWriter
else:
req.reader = req.transp.newAsyncStreamReader
req.writer = req.transp.newAsyncStreamWriter
else:
req.reader = req.transp.newAsyncStreamReader
req.writer = req.transp.newAsyncStreamWriter
while not transp.atEof():
let retry = await processRequest(req.server, req)
handlePostProcessMiddlewares(req)
if not retry:
await req.reader.closeWait
await req.writer.closeWait
await transp.closeWait
break
while not transp.atEof():
let retry = await processRequest(req.server, req, isSecurity)
handlePostProcessMiddlewares(req)
if not retry:
await req.reader.closeWait
await req.writer.closeWait
await transp.closeWait
break
processClient

proc logSubOnNext(v: string) =
echo v
Expand Down Expand Up @@ -904,7 +907,7 @@ proc serve*(address: string,
callback: ScorperCallback,
flags: set[ServerFlags] = {ReuseAddr},
maxBody = 8.Mb,
isSecurity = false,
isSecurity: static[bool] = false,
privateKey: string = "",
certificate: string = "",
secureFlags: set[TLSFlags] = {},
Expand All @@ -917,10 +920,10 @@ proc serve*(address: string,
server.callback = callback
server.maxBody = maxBody
let address = initTAddress(address)
server = cast[Scorper](createStreamServer(address, processClient, flags, child = cast[StreamServer](server)))
server = cast[Scorper](createStreamServer(address, processClientImpl(isSecurity), flags, child = cast[StreamServer](server)))
server.initScorper()
when defined(ssl):
if isSecurity:
when isSecurity:
server.initSecurityScorper(secureFlags, privateKey, certificate, tlsMinVersion, tlsMaxVersion)
server.start()

Expand All @@ -936,7 +939,7 @@ proc setHandler*(self: Scorper, handler: ScorperCallback) {.inline, raises: [].}
proc newScorper*(address: string, handler: ScorperCallback | Router[ScorperCallback] = default(ScorperCallback),
flags: set[ServerFlags] = {ReuseAddr},
maxBody = 8.Mb,
isSecurity = false,
isSecurity: static[bool] = false,
privateKey: string = "",
certificate: string = "",
secureFlags: set[TLSFlags] = {},
Expand All @@ -953,10 +956,10 @@ proc newScorper*(address: string, handler: ScorperCallback | Router[ScorperCallb
result.mimeDb = newScorperMimetypes()
result.maxBody = maxBody
let address = initTAddress(address)
result = cast[Scorper](createStreamServer(address, processClient, flags, child = cast[StreamServer](result)))
result = cast[Scorper](createStreamServer(address, processClientImpl(isSecurity), flags, child = cast[StreamServer](result)))
result.initScorper()
when defined(ssl):
if isSecurity:
when isSecurity:
result.initSecurityScorper(secureFlags, privateKey, certificate, tlsMinVersion, tlsMaxVersion)

func isClosed*(server: Scorper): bool =
Expand Down

0 comments on commit 410054f

Please sign in to comment.