diff --git a/spec/std/http/server/server_spec.cr b/spec/std/http/server/server_spec.cr index d345472e5a5e..296b2a200343 100644 --- a/spec/std/http/server/server_spec.cr +++ b/spec/std/http/server/server_spec.cr @@ -285,6 +285,66 @@ module HTTP HTTP::Client.get("http://#{address1}/").body.should eq "Test Server (#{address1})" end + it "handles Expect: 100-continue correctly when body is read" do + server = Server.new do |context| + context.response << context.request.body.not_nil!.gets_to_end + end + + address = server.bind_unused_port + spawn server.listen + + wait_for { server.listening? } + + TCPSocket.open(address.address, address.port) do |socket| + socket << requestize(<<-REQUEST + POST / HTTP/1.1 + Expect: 100-continue + Content-Length: 5 + + REQUEST + ) + socket << "\r\n" + socket.flush + + response = Client::Response.from_io(socket) + response.status_code.should eq(100) + + socket << "hello" + socket.flush + + response = Client::Response.from_io(socket) + response.status_code.should eq(200) + response.body.should eq("hello") + end + end + + it "handles Expect: 100-continue correctly when body isn't read" do + server = Server.new do |context| + context.response.respond_with_error("I don't want your body", 400) + end + + address = server.bind_unused_port + spawn server.listen + + wait_for { server.listening? } + + TCPSocket.open(address.address, address.port) do |socket| + socket << requestize(<<-REQUEST + POST / HTTP/1.1 + Expect: 100-continue + Content-Length: 5 + + REQUEST + ) + socket << "\r\n" + socket.flush + + response = Client::Response.from_io(socket) + response.status_code.should eq(400) + response.body.should eq("400 I don't want your body\n") + end + end + it "lists addresses" do server = Server.new { } diff --git a/src/http/common.cr b/src/http/common.cr index 1907bef8c8d5..bc6689c55592 100644 --- a/src/http/common.cr +++ b/src/http/common.cr @@ -27,6 +27,7 @@ module HTTP if line.empty? body = nil + if body_type.prohibited? body = nil elsif content_length = content_length(headers) @@ -40,6 +41,10 @@ module HTTP body = UnknownLengthContent.new(io) end + if body.is_a?(Content) && expect_continue?(headers) + body.expects_continue = true + end + if decompress && body {% if flag?(:without_zlib) %} raise "Can't decompress because `-D without_zlib` was passed at compile time" @@ -194,6 +199,10 @@ module HTTP end end + def self.expect_continue?(headers) + headers["Expect"]?.try(&.downcase) == "100-continue" + end + record ComputedContentTypeHeader, content_type : String?, charset : String? diff --git a/src/http/content.cr b/src/http/content.cr index da5875570482..218821dc127e 100644 --- a/src/http/content.cr +++ b/src/http/content.cr @@ -3,16 +3,50 @@ require "http/common" module HTTP # :nodoc: module Content + CONTINUE = "HTTP/1.1 100 Continue\r\n\r\n" + + @continue_sent = false + setter expects_continue : Bool = false + def close + @expects_continue = false skip_to_end super end + + protected def ensure_send_continue + return unless @expects_continue + return if @continue_sent + @io << CONTINUE + @io.flush + @continue_sent = true + end end # :nodoc: class FixedLengthContent < IO::Sized include Content + def read(slice : Bytes) + ensure_send_continue + super + end + + def read_byte + ensure_send_continue + super + end + + def peek + ensure_send_continue + super + end + + def skip(bytes_count) + ensure_send_continue + super + end + def write(slice : Bytes) raise IO::Error.new "Can't write to FixedLengthContent" end @@ -26,14 +60,17 @@ module HTTP end def read(slice : Bytes) + ensure_send_continue @io.read(slice) end def read_byte + ensure_send_continue @io.read_byte end def peek + ensure_send_continue @io.peek end @@ -66,6 +103,7 @@ module HTTP end def read(slice : Bytes) + ensure_send_continue count = slice.size return 0 if count == 0 @@ -87,6 +125,7 @@ module HTTP end def read_byte + ensure_send_continue next_chunk return super if @received_final_chunk @@ -100,6 +139,7 @@ module HTTP end def peek + ensure_send_continue next_chunk return Bytes.empty if @received_final_chunk @@ -115,6 +155,7 @@ module HTTP end def skip(bytes_count) + ensure_send_continue if bytes_count <= @chunk_remaining @io.skip(bytes_count) @chunk_remaining -= bytes_count