libgo: update to Go 1.7.1 release
authorIan Lance Taylor <ian@gcc.gnu.org>
Sat, 10 Sep 2016 13:14:00 +0000 (13:14 +0000)
committerIan Lance Taylor <ian@gcc.gnu.org>
Sat, 10 Sep 2016 13:14:00 +0000 (13:14 +0000)
    Reviewed-on: https://go-review.googlesource.com/29012

From-SVN: r240071

29 files changed:
gcc/go/gofrontend/MERGE
libgo/MERGE
libgo/VERSION
libgo/go/compress/flate/deflate.go
libgo/go/compress/flate/deflate_test.go
libgo/go/compress/flate/huffman_bit_writer.go
libgo/go/context/context_test.go
libgo/go/crypto/x509/root_cgo_darwin.go
libgo/go/hash/crc32/crc32_s390x.go
libgo/go/io/multi.go
libgo/go/io/multi_test.go
libgo/go/net/dial_unix_test.go [new file with mode: 0644]
libgo/go/net/dnsclient_unix.go
libgo/go/net/dnsclient_unix_test.go
libgo/go/net/fd_unix.go
libgo/go/net/hook_unix.go
libgo/go/net/http/h2_bundle.go
libgo/go/net/http/serve_test.go
libgo/go/net/http/server.go
libgo/go/net/http/transport.go
libgo/go/net/http/transport_internal_test.go
libgo/go/net/http/transport_test.go
libgo/go/os/wait_waitid.go
libgo/go/path/filepath/export_windows_test.go
libgo/go/path/filepath/path_test.go
libgo/go/path/filepath/symlink_windows.go
libgo/go/reflect/all_test.go
libgo/go/runtime/pprof/pprof.go
libgo/go/syscall/syscall_darwin_test.go [new file with mode: 0644]

index 487adbe8e50315aa933e98cf8f1142a88c76684b..7b3a8aa8d63a3838fe0eefa25cdffed8948d4397 100644 (file)
@@ -1,4 +1,4 @@
-04fe765560107e5d4c5f98c1022765930a1069f9
+d3a145b111a4f4ea772b812c6a0b3a853c207819
 
 The first line of this file holds the git revision number of the last
 merge done from the gofrontend repository.
index dc6e379977099d1a564df6a1cbcb511c9ace2058..160cfe323cb9d595370203ce26aa995d8b6551e9 100644 (file)
@@ -1,4 +1,4 @@
-8707f31c0abc6b607014e843b7cc188b3019daa9
+f75aafdf56dd90eab75cfeac8cf69358f73ba171
 
 The first line of this file holds the git revision number of the
 last merge done from the master library sources.
index a323ae8190fee00532a0dfe3417c6c05ff775e9b..ee106b3bb2ddda388231265a1b0a5d2be1b3563d 100644 (file)
@@ -1 +1 @@
-go1.7rc3
+go1.7.1
index 8a085ba34740f9ed9152615dad3e339b30f9bd0b..9f53d51a6e751ffe7157f54efeb5d34afd1e92ed 100644 (file)
@@ -15,7 +15,17 @@ const (
        BestSpeed          = 1
        BestCompression    = 9
        DefaultCompression = -1
-       HuffmanOnly        = -2 // Disables match search and only does Huffman entropy reduction.
+
+       // HuffmanOnly disables Lempel-Ziv match searching and only performs Huffman
+       // entropy encoding. This mode is useful in compressing data that has
+       // already been compressed with an LZ style algorithm (e.g. Snappy or LZ4)
+       // that lacks an entropy encoder. Compression gains are achieved when
+       // certain bytes in the input stream occur more frequently than others.
+       //
+       // Note that HuffmanOnly produces a compressed output that is
+       // RFC 1951 compliant. That is, any valid DEFLATE decompressor will
+       // continue to be able to decompress this output.
+       HuffmanOnly = -2
 )
 
 const (
@@ -644,7 +654,6 @@ func (d *compressor) close() error {
 // a very fast compression for all types of input, but sacrificing considerable
 // compression efficiency.
 //
-//
 // If level is in the range [-2, 9] then the error returned will be nil.
 // Otherwise the error returned will be non-nil.
 func NewWriter(w io.Writer, level int) (*Writer, error) {
@@ -715,7 +724,7 @@ func (w *Writer) Close() error {
 // the result of NewWriter or NewWriterDict called with dst
 // and w's level and dictionary.
 func (w *Writer) Reset(dst io.Writer) {
-       if dw, ok := w.d.w.w.(*dictWriter); ok {
+       if dw, ok := w.d.w.writer.(*dictWriter); ok {
                // w was created with NewWriterDict
                dw.w = dst
                w.d.reset(dw)
index 27a3b3823a649401df3c58f8be497328feb6e7a5..3322c40845d0678d29b51f7326d4cd045c0c593e 100644 (file)
@@ -6,6 +6,7 @@ package flate
 
 import (
        "bytes"
+       "errors"
        "fmt"
        "internal/testenv"
        "io"
@@ -631,3 +632,52 @@ func TestBestSpeed(t *testing.T) {
                }
        }
 }
+
+var errIO = errors.New("IO error")
+
+// failWriter fails with errIO exactly at the nth call to Write.
+type failWriter struct{ n int }
+
+func (w *failWriter) Write(b []byte) (int, error) {
+       w.n--
+       if w.n == -1 {
+               return 0, errIO
+       }
+       return len(b), nil
+}
+
+func TestWriterPersistentError(t *testing.T) {
+       d, err := ioutil.ReadFile("../testdata/Mark.Twain-Tom.Sawyer.txt")
+       if err != nil {
+               t.Fatalf("ReadFile: %v", err)
+       }
+       d = d[:10000] // Keep this test short
+
+       zw, err := NewWriter(nil, DefaultCompression)
+       if err != nil {
+               t.Fatalf("NewWriter: %v", err)
+       }
+
+       // Sweep over the threshold at which an error is returned.
+       // The variable i makes it such that the ith call to failWriter.Write will
+       // return errIO. Since failWriter errors are not persistent, we must ensure
+       // that flate.Writer errors are persistent.
+       for i := 0; i < 1000; i++ {
+               fw := &failWriter{i}
+               zw.Reset(fw)
+
+               _, werr := zw.Write(d)
+               cerr := zw.Close()
+               if werr != errIO && werr != nil {
+                       t.Errorf("test %d, mismatching Write error: got %v, want %v", i, werr, errIO)
+               }
+               if cerr != errIO && fw.n < 0 {
+                       t.Errorf("test %d, mismatching Close error: got %v, want %v", i, cerr, errIO)
+               }
+               if fw.n >= 0 {
+                       // At this point, the failure threshold was sufficiently high enough
+                       // that we wrote the whole stream without any errors.
+                       return
+               }
+       }
+}
index c4adef9ff53010b6db8d86f824bf235dc8b87ffd..d8b5a3ebd7b50c7da64274dca161241e1bbb4d60 100644 (file)
@@ -77,7 +77,11 @@ var offsetBase = []uint32{
 var codegenOrder = []uint32{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}
 
 type huffmanBitWriter struct {
-       w io.Writer
+       // writer is the underlying writer.
+       // Do not use it directly; use the write method, which ensures
+       // that Write errors are sticky.
+       writer io.Writer
+
        // Data waiting to be written is bytes[0:nbytes]
        // and then the low nbits of bits.
        bits            uint64
@@ -96,7 +100,7 @@ type huffmanBitWriter struct {
 
 func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
        return &huffmanBitWriter{
-               w:               w,
+               writer:          w,
                literalFreq:     make([]int32, maxNumLit),
                offsetFreq:      make([]int32, offsetCodeCount),
                codegen:         make([]uint8, maxNumLit+offsetCodeCount+1),
@@ -107,7 +111,7 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
 }
 
 func (w *huffmanBitWriter) reset(writer io.Writer) {
-       w.w = writer
+       w.writer = writer
        w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil
        w.bytes = [bufferSize]byte{}
 }
@@ -129,11 +133,21 @@ func (w *huffmanBitWriter) flush() {
                n++
        }
        w.bits = 0
-       _, w.err = w.w.Write(w.bytes[:n])
+       w.write(w.bytes[:n])
        w.nbytes = 0
 }
 
+func (w *huffmanBitWriter) write(b []byte) {
+       if w.err != nil {
+               return
+       }
+       _, w.err = w.writer.Write(b)
+}
+
 func (w *huffmanBitWriter) writeBits(b int32, nb uint) {
+       if w.err != nil {
+               return
+       }
        w.bits |= uint64(b) << w.nbits
        w.nbits += nb
        if w.nbits >= 48 {
@@ -150,7 +164,7 @@ func (w *huffmanBitWriter) writeBits(b int32, nb uint) {
                bytes[5] = byte(bits >> 40)
                n += 6
                if n >= bufferFlushSize {
-                       _, w.err = w.w.Write(w.bytes[:n])
+                       w.write(w.bytes[:n])
                        n = 0
                }
                w.nbytes = n
@@ -173,13 +187,10 @@ func (w *huffmanBitWriter) writeBytes(bytes []byte) {
                n++
        }
        if n != 0 {
-               _, w.err = w.w.Write(w.bytes[:n])
-               if w.err != nil {
-                       return
-               }
+               w.write(w.bytes[:n])
        }
        w.nbytes = 0
-       _, w.err = w.w.Write(bytes)
+       w.write(bytes)
 }
 
 // RFC 1951 3.2.7 specifies a special run-length encoding for specifying
@@ -341,7 +352,7 @@ func (w *huffmanBitWriter) writeCode(c hcode) {
                bytes[5] = byte(bits >> 40)
                n += 6
                if n >= bufferFlushSize {
-                       _, w.err = w.w.Write(w.bytes[:n])
+                       w.write(w.bytes[:n])
                        n = 0
                }
                w.nbytes = n
@@ -572,6 +583,9 @@ func (w *huffmanBitWriter) indexTokens(tokens []token) (numLiterals, numOffsets
 // writeTokens writes a slice of tokens to the output.
 // codes for literal and offset encoding must be supplied.
 func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode) {
+       if w.err != nil {
+               return
+       }
        for _, t := range tokens {
                if t < matchType {
                        w.writeCode(leCodes[t.literal()])
@@ -676,9 +690,9 @@ func (w *huffmanBitWriter) writeBlockHuff(eof bool, input []byte) {
                if n < bufferFlushSize {
                        continue
                }
-               _, w.err = w.w.Write(w.bytes[:n])
+               w.write(w.bytes[:n])
                if w.err != nil {
-                       return
+                       return // Return early in the event of write failures
                }
                n = 0
        }
index 90e78e57ecc42e83623587c5c8dc9a2204ea3df0..cf182110fbdf9b7409b40cfa2efdfe1660c4f248 100644 (file)
@@ -255,6 +255,12 @@ func TestDeadline(t *testing.T) {
        o = otherContext{c}
        c, _ = WithDeadline(o, time.Now().Add(4*time.Second))
        testDeadline(c, "WithDeadline+otherContext+WithDeadline", 2*time.Second, t)
+
+       c, _ = WithDeadline(Background(), time.Now().Add(-time.Millisecond))
+       testDeadline(c, "WithDeadline+inthepast", time.Second, t)
+
+       c, _ = WithDeadline(Background(), time.Now())
+       testDeadline(c, "WithDeadline+now", time.Second, t)
 }
 
 func TestTimeout(t *testing.T) {
index 0e2fb357ee904034b46e96f8f06094bce2fad5b6..a4b33c7660ee7b33a0a417105c23dfc05d3dc2d4 100644 (file)
@@ -10,9 +10,65 @@ package x509
 #cgo CFLAGS: -mmacosx-version-min=10.6 -D__MAC_OS_X_VERSION_MAX_ALLOWED=1060
 #cgo LDFLAGS: -framework CoreFoundation -framework Security
 
+#include <errno.h>
+#include <sys/sysctl.h>
+
 #include <CoreFoundation/CoreFoundation.h>
 #include <Security/Security.h>
 
+// FetchPEMRoots_MountainLion is the version of FetchPEMRoots from Go 1.6
+// which still works on OS X 10.8 (Mountain Lion).
+// It lacks support for admin & user cert domains.
+// See golang.org/issue/16473
+int FetchPEMRoots_MountainLion(CFDataRef *pemRoots) {
+       if (pemRoots == NULL) {
+               return -1;
+       }
+       CFArrayRef certs = NULL;
+       OSStatus err = SecTrustCopyAnchorCertificates(&certs);
+       if (err != noErr) {
+               return -1;
+       }
+       CFMutableDataRef combinedData = CFDataCreateMutable(kCFAllocatorDefault, 0);
+       int i, ncerts = CFArrayGetCount(certs);
+       for (i = 0; i < ncerts; i++) {
+               CFDataRef data = NULL;
+               SecCertificateRef cert = (SecCertificateRef)CFArrayGetValueAtIndex(certs, i);
+               if (cert == NULL) {
+                       continue;
+               }
+               // Note: SecKeychainItemExport is deprecated as of 10.7 in favor of SecItemExport.
+               // Once we support weak imports via cgo we should prefer that, and fall back to this
+               // for older systems.
+               err = SecKeychainItemExport(cert, kSecFormatX509Cert, kSecItemPemArmour, NULL, &data);
+               if (err != noErr) {
+                       continue;
+               }
+               if (data != NULL) {
+                       CFDataAppendBytes(combinedData, CFDataGetBytePtr(data), CFDataGetLength(data));
+                       CFRelease(data);
+               }
+       }
+       CFRelease(certs);
+       *pemRoots = combinedData;
+       return 0;
+}
+
+// useOldCode reports whether the running machine is OS X 10.8 Mountain Lion
+// or older. We only support Mountain Lion and higher, but we'll at least try our
+// best on older machines and continue to use the old code path.
+//
+// See golang.org/issue/16473
+int useOldCode() {
+       char str[256];
+       size_t size = sizeof(str);
+       memset(str, 0, size);
+       sysctlbyname("kern.osrelease", str, &size, NULL, 0);
+       // OS X 10.8 is osrelease "12.*", 10.7 is 11.*, 10.6 is 10.*.
+       // We never supported things before that.
+       return memcmp(str, "12.", 3) == 0 || memcmp(str, "11.", 3) == 0 || memcmp(str, "10.", 3) == 0;
+}
+
 // FetchPEMRoots fetches the system's list of trusted X.509 root certificates.
 //
 // On success it returns 0 and fills pemRoots with a CFDataRef that contains the extracted root
@@ -21,6 +77,10 @@ package x509
 // Note: The CFDataRef returned in pemRoots must be released (using CFRelease) after
 // we've consumed its content.
 int FetchPEMRoots(CFDataRef *pemRoots) {
+       if (useOldCode()) {
+               return FetchPEMRoots_MountainLion(pemRoots);
+       }
+
        // Get certificates from all domains, not just System, this lets
        // the user add CAs to their "login" keychain, and Admins to add
        // to the "System" keychain
index b8a580870c504fb120a4cc08c4a4b3098b9598c4..2a7926b4796008809d2e462417506bbfb6be4159 100644 (file)
@@ -6,14 +6,9 @@
 
 package crc32
 
-import (
-       "unsafe"
-)
-
 const (
        vxMinLen    = 64
-       vxAlignment = 16
-       vxAlignMask = vxAlignment - 1
+       vxAlignMask = 15 // align to 16 bytes
 )
 
 // hasVectorFacility reports whether the machine has the z/Architecture
@@ -51,20 +46,13 @@ func genericIEEE(crc uint32, p []byte) uint32 {
        return update(crc, IEEETable, p)
 }
 
-// updateCastagnoli calculates the checksum of p using genericCastagnoli to
-// align the data appropriately for vectorCastagnoli. It avoids using
-// vectorCastagnoli entirely if the length of p is less than or equal to
-// vxMinLen.
+// updateCastagnoli calculates the checksum of p using
+// vectorizedCastagnoli if possible and falling back onto
+// genericCastagnoli as needed.
 func updateCastagnoli(crc uint32, p []byte) uint32 {
        // Use vectorized function if vector facility is available and
        // data length is above threshold.
-       if hasVX && len(p) > vxMinLen {
-               pAddr := uintptr(unsafe.Pointer(&p[0]))
-               if pAddr&vxAlignMask != 0 {
-                       prealign := vxAlignment - int(pAddr&vxAlignMask)
-                       crc = genericCastagnoli(crc, p[:prealign])
-                       p = p[prealign:]
-               }
+       if hasVX && len(p) >= vxMinLen {
                aligned := len(p) & ^vxAlignMask
                crc = vectorizedCastagnoli(crc, p[:aligned])
                p = p[aligned:]
@@ -77,19 +65,12 @@ func updateCastagnoli(crc uint32, p []byte) uint32 {
        return genericCastagnoli(crc, p)
 }
 
-// updateIEEE calculates the checksum of p using genericIEEE to align the data
-// appropriately for vectorIEEE. It avoids using vectorIEEE entirely if the length
-// of p is less than or equal to vxMinLen.
+// updateIEEE calculates the checksum of p using vectorizedIEEE if
+// possible and falling back onto genericIEEE as needed.
 func updateIEEE(crc uint32, p []byte) uint32 {
        // Use vectorized function if vector facility is available and
        // data length is above threshold.
-       if hasVX && len(p) > vxMinLen {
-               pAddr := uintptr(unsafe.Pointer(&p[0]))
-               if pAddr&vxAlignMask != 0 {
-                       prealign := vxAlignment - int(pAddr&vxAlignMask)
-                       crc = genericIEEE(crc, p[:prealign])
-                       p = p[prealign:]
-               }
+       if hasVX && len(p) >= vxMinLen {
                aligned := len(p) & ^vxAlignMask
                crc = vectorizedIEEE(crc, p[:aligned])
                p = p[aligned:]
index ed05cac9e722d5fcff7663ffe55d188fe83c4283..3a9d03652b07e67e7e09d122c4133e86a89cc3c7 100644 (file)
@@ -18,15 +18,16 @@ func (mr *multiReader) Read(p []byte) (n int, err error) {
                        }
                }
                n, err = mr.readers[0].Read(p)
+               if err == EOF {
+                       mr.readers = mr.readers[1:]
+               }
                if n > 0 || err != EOF {
-                       if err == EOF {
-                               // Don't return EOF yet. There may be more bytes
-                               // in the remaining readers.
+                       if err == EOF && len(mr.readers) > 0 {
+                               // Don't return EOF yet. More readers remain.
                                err = nil
                        }
                        return
                }
-               mr.readers = mr.readers[1:]
        }
        return 0, EOF
 }
index a434453f6b0b691842b898e835ee2fb434f5e96c..447e7f59635270ec799608b19a8103731f56a35b 100644 (file)
@@ -197,3 +197,41 @@ func TestMultiReaderFlatten(t *testing.T) {
                        myDepth+2, readDepth)
        }
 }
+
+// byteAndEOFReader is a Reader which reads one byte (the underlying
+// byte) and io.EOF at once in its Read call.
+type byteAndEOFReader byte
+
+func (b byteAndEOFReader) Read(p []byte) (n int, err error) {
+       if len(p) == 0 {
+               // Read(0 bytes) is useless. We expect no such useless
+               // calls in this test.
+               panic("unexpected call")
+       }
+       p[0] = byte(b)
+       return 1, EOF
+}
+
+// In Go 1.7, this yielded bytes forever.
+func TestMultiReaderSingleByteWithEOF(t *testing.T) {
+       got, err := ioutil.ReadAll(LimitReader(MultiReader(byteAndEOFReader('a'), byteAndEOFReader('b')), 10))
+       if err != nil {
+               t.Fatal(err)
+       }
+       const want = "ab"
+       if string(got) != want {
+               t.Errorf("got %q; want %q", got, want)
+       }
+}
+
+// Test that a reader returning (n, EOF) at the end of an MultiReader
+// chain continues to return EOF on its final read, rather than
+// yielding a (0, EOF).
+func TestMultiReaderFinalEOF(t *testing.T) {
+       r := MultiReader(bytes.NewReader(nil), byteAndEOFReader('a'))
+       buf := make([]byte, 2)
+       n, err := r.Read(buf)
+       if n != 1 || err != EOF {
+               t.Errorf("got %v, %v; want 1, EOF", n, err)
+       }
+}
diff --git a/libgo/go/net/dial_unix_test.go b/libgo/go/net/dial_unix_test.go
new file mode 100644 (file)
index 0000000..4705254
--- /dev/null
@@ -0,0 +1,108 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin dragonfly freebsd linux netbsd openbsd solaris
+
+package net
+
+import (
+       "context"
+       "syscall"
+       "testing"
+       "time"
+)
+
+// Issue 16523
+func TestDialContextCancelRace(t *testing.T) {
+       oldConnectFunc := connectFunc
+       oldGetsockoptIntFunc := getsockoptIntFunc
+       oldTestHookCanceledDial := testHookCanceledDial
+       defer func() {
+               connectFunc = oldConnectFunc
+               getsockoptIntFunc = oldGetsockoptIntFunc
+               testHookCanceledDial = oldTestHookCanceledDial
+       }()
+
+       ln, err := newLocalListener("tcp")
+       if err != nil {
+               t.Fatal(err)
+       }
+       listenerDone := make(chan struct{})
+       go func() {
+               defer close(listenerDone)
+               c, err := ln.Accept()
+               if err == nil {
+                       c.Close()
+               }
+       }()
+       defer func() { <-listenerDone }()
+       defer ln.Close()
+
+       sawCancel := make(chan bool, 1)
+       testHookCanceledDial = func() {
+               sawCancel <- true
+       }
+
+       ctx, cancelCtx := context.WithCancel(context.Background())
+
+       connectFunc = func(fd int, addr syscall.Sockaddr) error {
+               err := oldConnectFunc(fd, addr)
+               t.Logf("connect(%d, addr) = %v", fd, err)
+               if err == nil {
+                       // On some operating systems, localhost
+                       // connects _sometimes_ succeed immediately.
+                       // Prevent that, so we exercise the code path
+                       // we're interested in testing. This seems
+                       // harmless. It makes FreeBSD 10.10 work when
+                       // run with many iterations. It failed about
+                       // half the time previously.
+                       return syscall.EINPROGRESS
+               }
+               return err
+       }
+
+       getsockoptIntFunc = func(fd, level, opt int) (val int, err error) {
+               val, err = oldGetsockoptIntFunc(fd, level, opt)
+               t.Logf("getsockoptIntFunc(%d, %d, %d) = (%v, %v)", fd, level, opt, val, err)
+               if level == syscall.SOL_SOCKET && opt == syscall.SO_ERROR && err == nil && val == 0 {
+                       t.Logf("canceling context")
+
+                       // Cancel the context at just the moment which
+                       // caused the race in issue 16523.
+                       cancelCtx()
+
+                       // And wait for the "interrupter" goroutine to
+                       // cancel the dial by messing with its write
+                       // timeout before returning.
+                       select {
+                       case <-sawCancel:
+                               t.Logf("saw cancel")
+                       case <-time.After(5 * time.Second):
+                               t.Errorf("didn't see cancel after 5 seconds")
+                       }
+               }
+               return
+       }
+
+       var d Dialer
+       c, err := d.DialContext(ctx, "tcp", ln.Addr().String())
+       if err == nil {
+               c.Close()
+               t.Fatal("unexpected successful dial; want context canceled error")
+       }
+
+       select {
+       case <-ctx.Done():
+       case <-time.After(5 * time.Second):
+               t.Fatal("expected context to be canceled")
+       }
+
+       oe, ok := err.(*OpError)
+       if !ok || oe.Op != "dial" {
+               t.Fatalf("Dial error = %#v; want dial *OpError", err)
+       }
+       if oe.Err != ctx.Err() {
+               t.Errorf("DialContext = (%v, %v); want OpError with error %v", c, err, ctx.Err())
+       }
+}
index 8f2dff46751b0e31778a1fc1e92d030d630c5f19..b5b6ffb1c5060ebb025f984b2e1b1b6c334369d0 100644 (file)
@@ -141,7 +141,7 @@ func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn,
 }
 
 // exchange sends a query on the connection and hopes for a response.
-func exchange(ctx context.Context, server, name string, qtype uint16) (*dnsMsg, error) {
+func exchange(ctx context.Context, server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
        d := testHookDNSDialer()
        out := dnsMsg{
                dnsMsgHdr: dnsMsgHdr{
@@ -152,6 +152,12 @@ func exchange(ctx context.Context, server, name string, qtype uint16) (*dnsMsg,
                },
        }
        for _, network := range []string{"udp", "tcp"} {
+               // TODO(mdempsky): Refactor so defers from UDP-based
+               // exchanges happen before TCP-based exchange.
+
+               ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
+               defer cancel()
+
                c, err := d.dialDNS(ctx, network, server)
                if err != nil {
                        return nil, err
@@ -180,17 +186,10 @@ func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16)
                return "", nil, &DNSError{Err: "no DNS servers", Name: name}
        }
 
-       deadline := time.Now().Add(cfg.timeout)
-       if old, ok := ctx.Deadline(); !ok || deadline.Before(old) {
-               var cancel context.CancelFunc
-               ctx, cancel = context.WithDeadline(ctx, deadline)
-               defer cancel()
-       }
-
        var lastErr error
        for i := 0; i < cfg.attempts; i++ {
                for _, server := range cfg.servers {
-                       msg, err := exchange(ctx, server, name, qtype)
+                       msg, err := exchange(ctx, server, name, qtype, cfg.timeout)
                        if err != nil {
                                lastErr = &DNSError{
                                        Err:    err.Error(),
@@ -338,8 +337,9 @@ func lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs [
 }
 
 // avoidDNS reports whether this is a hostname for which we should not
-// use DNS. Currently this includes only .onion and .local names,
-// per RFC 7686 and RFC 6762, respectively. See golang.org/issue/13705.
+// use DNS. Currently this includes only .onion, per RFC 7686. See
+// golang.org/issue/13705. Does not cover .local names (RFC 6762),
+// see golang.org/issue/16739.
 func avoidDNS(name string) bool {
        if name == "" {
                return true
@@ -347,7 +347,7 @@ func avoidDNS(name string) bool {
        if name[len(name)-1] == '.' {
                name = name[:len(name)-1]
        }
-       return stringsHasSuffixFold(name, ".onion") || stringsHasSuffixFold(name, ".local")
+       return stringsHasSuffixFold(name, ".onion")
 }
 
 // nameList returns a list of names for sequential DNS queries.
index 09bbd488660673284de1408fa9feb1f2bf6314f6..6ebeeaeb8f48996ac39a67a9a90bb8d2953c44b9 100644 (file)
@@ -40,9 +40,9 @@ func TestDNSTransportFallback(t *testing.T) {
        testenv.MustHaveExternalNetwork(t)
 
        for _, tt := range dnsTransportFallbackTests {
-               ctx, cancel := context.WithTimeout(context.Background(), time.Duration(tt.timeout)*time.Second)
+               ctx, cancel := context.WithCancel(context.Background())
                defer cancel()
-               msg, err := exchange(ctx, tt.server, tt.name, tt.qtype)
+               msg, err := exchange(ctx, tt.server, tt.name, tt.qtype, time.Second)
                if err != nil {
                        t.Error(err)
                        continue
@@ -82,9 +82,9 @@ func TestSpecialDomainName(t *testing.T) {
 
        server := "8.8.8.8:53"
        for _, tt := range specialDomainNameTests {
-               ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+               ctx, cancel := context.WithCancel(context.Background())
                defer cancel()
-               msg, err := exchange(ctx, server, tt.name, tt.qtype)
+               msg, err := exchange(ctx, server, tt.name, tt.qtype, 3*time.Second)
                if err != nil {
                        t.Error(err)
                        continue
@@ -112,10 +112,11 @@ func TestAvoidDNSName(t *testing.T) {
                {"foo.ONION", true},
                {"foo.ONION.", true},
 
-               {"foo.local.", true},
-               {"foo.local", true},
-               {"foo.LOCAL", true},
-               {"foo.LOCAL.", true},
+               // But do resolve *.local address; Issue 16739
+               {"foo.local.", false},
+               {"foo.local", false},
+               {"foo.LOCAL", false},
+               {"foo.LOCAL.", false},
 
                {"", true}, // will be rejected earlier too
 
@@ -500,7 +501,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
        d := &fakeDNSDialer{}
        testHookDNSDialer = func() dnsDialer { return d }
 
-       d.rh = func(s string, q *dnsMsg) (*dnsMsg, error) {
+       d.rh = func(s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
                r := &dnsMsg{
                        dnsMsgHdr: dnsMsgHdr{
                                id: q.id,
@@ -539,14 +540,15 @@ func TestIgnoreLameReferrals(t *testing.T) {
        }
        defer conf.teardown()
 
-       if err := conf.writeAndUpdate([]string{"nameserver 192.0.2.1", "nameserver 192.0.2.2"}); err != nil {
+       if err := conf.writeAndUpdate([]string{"nameserver 192.0.2.1", // the one that will give a lame referral
+               "nameserver 192.0.2.2"}); err != nil {
                t.Fatal(err)
        }
 
        d := &fakeDNSDialer{}
        testHookDNSDialer = func() dnsDialer { return d }
 
-       d.rh = func(s string, q *dnsMsg) (*dnsMsg, error) {
+       d.rh = func(s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
                t.Log(s, q)
                r := &dnsMsg{
                        dnsMsgHdr: dnsMsgHdr{
@@ -633,28 +635,30 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
 
 type fakeDNSDialer struct {
        // reply handler
-       rh func(s string, q *dnsMsg) (*dnsMsg, error)
+       rh func(s string, q *dnsMsg, t time.Time) (*dnsMsg, error)
 }
 
 func (f *fakeDNSDialer) dialDNS(_ context.Context, n, s string) (dnsConn, error) {
-       return &fakeDNSConn{f.rh, s}, nil
+       return &fakeDNSConn{f.rh, s, time.Time{}}, nil
 }
 
 type fakeDNSConn struct {
-       rh func(s string, q *dnsMsg) (*dnsMsg, error)
+       rh func(s string, q *dnsMsg, t time.Time) (*dnsMsg, error)
        s  string
+       t  time.Time
 }
 
 func (f *fakeDNSConn) Close() error {
        return nil
 }
 
-func (f *fakeDNSConn) SetDeadline(time.Time) error {
+func (f *fakeDNSConn) SetDeadline(t time.Time) error {
+       f.t = t
        return nil
 }
 
 func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) {
-       return f.rh(f.s, q)
+       return f.rh(f.s, q, f.t)
 }
 
 // UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
@@ -724,3 +728,72 @@ func TestIgnoreDNSForgeries(t *testing.T) {
                t.Errorf("got address %v, want %v", got, TestAddr)
        }
 }
+
+// Issue 16865. If a name server times out, continue to the next.
+func TestRetryTimeout(t *testing.T) {
+       origTestHookDNSDialer := testHookDNSDialer
+       defer func() { testHookDNSDialer = origTestHookDNSDialer }()
+
+       conf, err := newResolvConfTest()
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer conf.teardown()
+
+       if err := conf.writeAndUpdate([]string{"nameserver 192.0.2.1", // the one that will timeout
+               "nameserver 192.0.2.2"}); err != nil {
+               t.Fatal(err)
+       }
+
+       d := &fakeDNSDialer{}
+       testHookDNSDialer = func() dnsDialer { return d }
+
+       var deadline0 time.Time
+
+       d.rh = func(s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
+               t.Log(s, q, deadline)
+
+               if deadline.IsZero() {
+                       t.Error("zero deadline")
+               }
+
+               if s == "192.0.2.1:53" {
+                       deadline0 = deadline
+                       time.Sleep(10 * time.Millisecond)
+                       return nil, errTimeout
+               }
+
+               if deadline == deadline0 {
+                       t.Error("deadline didn't change")
+               }
+
+               r := &dnsMsg{
+                       dnsMsgHdr: dnsMsgHdr{
+                               id:                  q.id,
+                               response:            true,
+                               recursion_available: true,
+                       },
+                       question: q.question,
+                       answer: []dnsRR{
+                               &dnsRR_CNAME{
+                                       Hdr: dnsRR_Header{
+                                               Name:   q.question[0].Name,
+                                               Rrtype: dnsTypeCNAME,
+                                               Class:  dnsClassINET,
+                                       },
+                                       Cname: "golang.org",
+                               },
+                       },
+               }
+               return r, nil
+       }
+
+       _, err = goLookupCNAME(context.Background(), "www.golang.org")
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       if deadline0.IsZero() {
+               t.Error("deadline0 still zero", deadline0)
+       }
+}
index 6fbb9cbf1140c18d52a5e323732cc3d3e05363a4..0309db08ebce468dd29b67842a9a8313abfee7a1 100644 (file)
@@ -64,7 +64,7 @@ func (fd *netFD) name() string {
        return fd.net + ":" + ls + "->" + rs
 }
 
-func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error {
+func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (ret error) {
        // Do not need to call fd.writeLock here,
        // because fd is not yet accessible to user,
        // so no concurrent operations are possible.
@@ -101,21 +101,44 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error {
                defer fd.setWriteDeadline(noDeadline)
        }
 
-       // Wait for the goroutine converting context.Done into a write timeout
-       // to exist, otherwise our caller might cancel the context and
-       // cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial.
-       done := make(chan bool) // must be unbuffered
-       defer func() { done <- true }()
-       go func() {
-               select {
-               case <-ctx.Done():
-                       // Force the runtime's poller to immediately give
-                       // up waiting for writability.
-                       fd.setWriteDeadline(aLongTimeAgo)
-                       <-done
-               case <-done:
-               }
-       }()
+       // Start the "interrupter" goroutine, if this context might be canceled.
+       // (The background context cannot)
+       //
+       // The interrupter goroutine waits for the context to be done and
+       // interrupts the dial (by altering the fd's write deadline, which
+       // wakes up waitWrite).
+       if ctx != context.Background() {
+               // Wait for the interrupter goroutine to exit before returning
+               // from connect.
+               done := make(chan struct{})
+               interruptRes := make(chan error)
+               defer func() {
+                       close(done)
+                       if ctxErr := <-interruptRes; ctxErr != nil && ret == nil {
+                               // The interrupter goroutine called setWriteDeadline,
+                               // but the connect code below had returned from
+                               // waitWrite already and did a successful connect (ret
+                               // == nil). Because we've now poisoned the connection
+                               // by making it unwritable, don't return a successful
+                               // dial. This was issue 16523.
+                               ret = ctxErr
+                               fd.Close() // prevent a leak
+                       }
+               }()
+               go func() {
+                       select {
+                       case <-ctx.Done():
+                               // Force the runtime's poller to immediately give up
+                               // waiting for writability, unblocking waitWrite
+                               // below.
+                               fd.setWriteDeadline(aLongTimeAgo)
+                               testHookCanceledDial()
+                               interruptRes <- ctx.Err()
+                       case <-done:
+                               interruptRes <- nil
+                       }
+               }()
+       }
 
        for {
                // Performing multiple connect system calls on a
index 361ca5980c38b24abc526e8a346fa6c2c0114efa..cf52567fcfdad88ce65d3f5d39650bdb4e7a6610 100644 (file)
@@ -9,7 +9,8 @@ package net
 import "syscall"
 
 var (
-       testHookDialChannel = func() {} // see golang.org/issue/5349
+       testHookDialChannel  = func() {} // for golang.org/issue/5349
+       testHookCanceledDial = func() {} // for golang.org/issue/16523
 
        // Placeholders for socket system calls.
        socketFunc        func(int, int, int) (int, error)         = syscall.Socket
index db774554b2ca94f736e302d7def0f3cd4bf02e05..5826bb7d858a4137d10e1b584dea640fa5dc3137 100644 (file)
@@ -28,6 +28,7 @@ import (
        "io"
        "io/ioutil"
        "log"
+       "math"
        "net"
        "net/http/httptrace"
        "net/textproto"
@@ -85,7 +86,16 @@ const (
        http2noDialOnMiss = false
 )
 
-func (p *http2clientConnPool) getClientConn(_ *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) {
+func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) {
+       if http2isConnectionCloseRequest(req) && dialOnMiss {
+               // It gets its own connection.
+               const singleUse = true
+               cc, err := p.t.dialClientConn(addr, singleUse)
+               if err != nil {
+                       return nil, err
+               }
+               return cc, nil
+       }
        p.mu.Lock()
        for _, cc := range p.conns[addr] {
                if cc.CanTakeNewRequest() {
@@ -128,7 +138,8 @@ func (p *http2clientConnPool) getStartDialLocked(addr string) *http2dialCall {
 
 // run in its own goroutine.
 func (c *http2dialCall) dial(addr string) {
-       c.res, c.err = c.p.t.dialClientConn(addr)
+       const singleUse = false // shared conn
+       c.res, c.err = c.p.t.dialClientConn(addr, singleUse)
        close(c.done)
 
        c.p.mu.Lock()
@@ -393,9 +404,17 @@ func (e http2ConnectionError) Error() string {
 type http2StreamError struct {
        StreamID uint32
        Code     http2ErrCode
+       Cause    error // optional additional detail
+}
+
+func http2streamError(id uint32, code http2ErrCode) http2StreamError {
+       return http2StreamError{StreamID: id, Code: code}
 }
 
 func (e http2StreamError) Error() string {
+       if e.Cause != nil {
+               return fmt.Sprintf("stream error: stream ID %d; %v; %v", e.StreamID, e.Code, e.Cause)
+       }
        return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code)
 }
 
@@ -1105,6 +1124,7 @@ func http2parseDataFrame(fh http2FrameHeader, payload []byte) (http2Frame, error
 var (
        http2errStreamID    = errors.New("invalid stream ID")
        http2errDepStreamID = errors.New("invalid dependent stream ID")
+       http2errPadLength   = errors.New("pad length too large")
 )
 
 func http2validStreamIDOrZero(streamID uint32) bool {
@@ -1118,18 +1138,40 @@ func http2validStreamID(streamID uint32) bool {
 // WriteData writes a DATA frame.
 //
 // It will perform exactly one Write to the underlying Writer.
-// It is the caller's responsibility to not call other Write methods concurrently.
+// It is the caller's responsibility not to violate the maximum frame size
+// and to not call other Write methods concurrently.
 func (f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) error {
+       return f.WriteDataPadded(streamID, endStream, data, nil)
+}
 
+// WriteData writes a DATA frame with optional padding.
+//
+// If pad is nil, the padding bit is not sent.
+// The length of pad must not exceed 255 bytes.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility not to violate the maximum frame size
+// and to not call other Write methods concurrently.
+func (f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error {
        if !http2validStreamID(streamID) && !f.AllowIllegalWrites {
                return http2errStreamID
        }
+       if len(pad) > 255 {
+               return http2errPadLength
+       }
        var flags http2Flags
        if endStream {
                flags |= http2FlagDataEndStream
        }
+       if pad != nil {
+               flags |= http2FlagDataPadded
+       }
        f.startWrite(http2FrameData, flags, streamID)
+       if pad != nil {
+               f.wbuf = append(f.wbuf, byte(len(pad)))
+       }
        f.wbuf = append(f.wbuf, data...)
+       f.wbuf = append(f.wbuf, pad...)
        return f.endWrite()
 }
 
@@ -1333,7 +1375,7 @@ func http2parseWindowUpdateFrame(fh http2FrameHeader, p []byte) (http2Frame, err
                if fh.StreamID == 0 {
                        return nil, http2ConnectionError(http2ErrCodeProtocol)
                }
-               return nil, http2StreamError{fh.StreamID, http2ErrCodeProtocol}
+               return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol)
        }
        return &http2WindowUpdateFrame{
                http2FrameHeader: fh,
@@ -1411,7 +1453,7 @@ func http2parseHeadersFrame(fh http2FrameHeader, p []byte) (_ http2Frame, err er
                }
        }
        if len(p)-int(padLength) <= 0 {
-               return nil, http2StreamError{fh.StreamID, http2ErrCodeProtocol}
+               return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol)
        }
        hf.headerFragBuf = p[:len(p)-int(padLength)]
        return hf, nil
@@ -1878,6 +1920,9 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFr
        hdec.SetEmitEnabled(true)
        hdec.SetMaxStringLength(fr.maxHeaderStringLen())
        hdec.SetEmitFunc(func(hf hpack.HeaderField) {
+               if http2VerboseLogs && http2logFrameReads {
+                       log.Printf("http2: decoded hpack field %+v", hf)
+               }
                if !httplex.ValidHeaderFieldValue(hf.Value) {
                        invalid = http2headerFieldValueError(hf.Value)
                }
@@ -1936,11 +1981,17 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFr
        }
        if invalid != nil {
                fr.errDetail = invalid
-               return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol}
+               if http2VerboseLogs {
+                       log.Printf("http2: invalid header: %v", invalid)
+               }
+               return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, invalid}
        }
        if err := mh.checkPseudos(); err != nil {
                fr.errDetail = err
-               return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol}
+               if http2VerboseLogs {
+                       log.Printf("http2: invalid pseudo headers: %v", err)
+               }
+               return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, err}
        }
        return mh, nil
 }
@@ -3571,7 +3622,7 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) {
                case http2stateOpen:
 
                        st.state = http2stateHalfClosedLocal
-                       errCancel := http2StreamError{st.id, http2ErrCodeCancel}
+                       errCancel := http2streamError(st.id, http2ErrCodeCancel)
                        sc.resetStream(errCancel)
                case http2stateHalfClosedRemote:
                        sc.closeStream(st, http2errHandlerComplete)
@@ -3764,7 +3815,7 @@ func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error
                        return nil
                }
                if !st.flow.add(int32(f.Increment)) {
-                       return http2StreamError{f.StreamID, http2ErrCodeFlowControl}
+                       return http2streamError(f.StreamID, http2ErrCodeFlowControl)
                }
        default:
                if !sc.flow.add(int32(f.Increment)) {
@@ -3786,7 +3837,7 @@ func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error {
        if st != nil {
                st.gotReset = true
                st.cancelCtx()
-               sc.closeStream(st, http2StreamError{f.StreamID, f.ErrCode})
+               sc.closeStream(st, http2streamError(f.StreamID, f.ErrCode))
        }
        return nil
 }
@@ -3803,6 +3854,9 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) {
        }
        delete(sc.streams, st.id)
        if p := st.body; p != nil {
+
+               sc.sendWindowUpdate(nil, p.Len())
+
                p.CloseWithError(err)
        }
        st.cw.Close()
@@ -3879,36 +3933,51 @@ func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error {
 
 func (sc *http2serverConn) processData(f *http2DataFrame) error {
        sc.serveG.check()
+       data := f.Data()
 
        id := f.Header().StreamID
        st, ok := sc.streams[id]
        if !ok || st.state != http2stateOpen || st.gotTrailerHeader {
 
-               return http2StreamError{id, http2ErrCodeStreamClosed}
+               if sc.inflow.available() < int32(f.Length) {
+                       return http2streamError(id, http2ErrCodeFlowControl)
+               }
+
+               sc.inflow.take(int32(f.Length))
+               sc.sendWindowUpdate(nil, int(f.Length))
+
+               return http2streamError(id, http2ErrCodeStreamClosed)
        }
        if st.body == nil {
                panic("internal error: should have a body in this state")
        }
-       data := f.Data()
 
        if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes {
                st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))
-               return http2StreamError{id, http2ErrCodeStreamClosed}
+               return http2streamError(id, http2ErrCodeStreamClosed)
        }
-       if len(data) > 0 {
+       if f.Length > 0 {
 
-               if int(st.inflow.available()) < len(data) {
-                       return http2StreamError{id, http2ErrCodeFlowControl}
+               if st.inflow.available() < int32(f.Length) {
+                       return http2streamError(id, http2ErrCodeFlowControl)
                }
-               st.inflow.take(int32(len(data)))
-               wrote, err := st.body.Write(data)
-               if err != nil {
-                       return http2StreamError{id, http2ErrCodeStreamClosed}
+               st.inflow.take(int32(f.Length))
+
+               if len(data) > 0 {
+                       wrote, err := st.body.Write(data)
+                       if err != nil {
+                               return http2streamError(id, http2ErrCodeStreamClosed)
+                       }
+                       if wrote != len(data) {
+                               panic("internal error: bad Writer")
+                       }
+                       st.bodyBytes += int64(len(data))
                }
-               if wrote != len(data) {
-                       panic("internal error: bad Writer")
+
+               if pad := int32(f.Length) - int32(len(data)); pad > 0 {
+                       sc.sendWindowUpdate32(nil, pad)
+                       sc.sendWindowUpdate32(st, pad)
                }
-               st.bodyBytes += int64(len(data))
        }
        if f.StreamEnded() {
                st.endStream()
@@ -3995,10 +4064,10 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error {
 
                if sc.unackedSettings == 0 {
 
-                       return http2StreamError{st.id, http2ErrCodeProtocol}
+                       return http2streamError(st.id, http2ErrCodeProtocol)
                }
 
-               return http2StreamError{st.id, http2ErrCodeRefusedStream}
+               return http2streamError(st.id, http2ErrCodeRefusedStream)
        }
 
        rw, req, err := sc.newWriterAndRequest(st, f)
@@ -4032,18 +4101,18 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error {
        }
        st.gotTrailerHeader = true
        if !f.StreamEnded() {
-               return http2StreamError{st.id, http2ErrCodeProtocol}
+               return http2streamError(st.id, http2ErrCodeProtocol)
        }
 
        if len(f.PseudoFields()) > 0 {
-               return http2StreamError{st.id, http2ErrCodeProtocol}
+               return http2streamError(st.id, http2ErrCodeProtocol)
        }
        if st.trailer != nil {
                for _, hf := range f.RegularFields() {
                        key := sc.canonicalHeader(hf.Name)
                        if !http2ValidTrailerHeader(key) {
 
-                               return http2StreamError{st.id, http2ErrCodeProtocol}
+                               return http2streamError(st.id, http2ErrCodeProtocol)
                        }
                        st.trailer[key] = append(st.trailer[key], hf.Value)
                }
@@ -4097,18 +4166,18 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead
        isConnect := method == "CONNECT"
        if isConnect {
                if path != "" || scheme != "" || authority == "" {
-                       return nil, nil, http2StreamError{f.StreamID, http2ErrCodeProtocol}
+                       return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol)
                }
        } else if method == "" || path == "" ||
                (scheme != "https" && scheme != "http") {
 
-               return nil, nil, http2StreamError{f.StreamID, http2ErrCodeProtocol}
+               return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol)
        }
 
        bodyOpen := !f.StreamEnded()
        if method == "HEAD" && bodyOpen {
 
-               return nil, nil, http2StreamError{f.StreamID, http2ErrCodeProtocol}
+               return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol)
        }
        var tlsState *tls.ConnectionState // nil if not scheme https
 
@@ -4165,7 +4234,7 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead
                var err error
                url_, err = url.ParseRequestURI(path)
                if err != nil {
-                       return nil, nil, http2StreamError{f.StreamID, http2ErrCodeProtocol}
+                       return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol)
                }
                requestURI = path
        }
@@ -4919,35 +4988,37 @@ func (t *http2Transport) initConnPool() {
 // ClientConn is the state of a single HTTP/2 client connection to an
 // HTTP/2 server.
 type http2ClientConn struct {
-       t        *http2Transport
-       tconn    net.Conn             // usually *tls.Conn, except specialized impls
-       tlsState *tls.ConnectionState // nil only for specialized impls
+       t         *http2Transport
+       tconn     net.Conn             // usually *tls.Conn, except specialized impls
+       tlsState  *tls.ConnectionState // nil only for specialized impls
+       singleUse bool                 // whether being used for a single http.Request
 
        // readLoop goroutine fields:
        readerDone chan struct{} // closed on error
        readerErr  error         // set before readerDone is closed
 
-       mu           sync.Mutex // guards following
-       cond         *sync.Cond // hold mu; broadcast on flow/closed changes
-       flow         http2flow  // our conn-level flow control quota (cs.flow is per stream)
-       inflow       http2flow  // peer's conn-level flow control
-       closed       bool
-       goAway       *http2GoAwayFrame             // if non-nil, the GoAwayFrame we received
-       goAwayDebug  string                        // goAway frame's debug data, retained as a string
-       streams      map[uint32]*http2clientStream // client-initiated
-       nextStreamID uint32
-       bw           *bufio.Writer
-       br           *bufio.Reader
-       fr           *http2Framer
-       lastActive   time.Time
-
-       // Settings from peer:
+       mu              sync.Mutex // guards following
+       cond            *sync.Cond // hold mu; broadcast on flow/closed changes
+       flow            http2flow  // our conn-level flow control quota (cs.flow is per stream)
+       inflow          http2flow  // peer's conn-level flow control
+       closed          bool
+       wantSettingsAck bool                          // we sent a SETTINGS frame and haven't heard back
+       goAway          *http2GoAwayFrame             // if non-nil, the GoAwayFrame we received
+       goAwayDebug     string                        // goAway frame's debug data, retained as a string
+       streams         map[uint32]*http2clientStream // client-initiated
+       nextStreamID    uint32
+       bw              *bufio.Writer
+       br              *bufio.Reader
+       fr              *http2Framer
+       lastActive      time.Time
+       // Settings from peer: (also guarded by mu)
        maxFrameSize         uint32
        maxConcurrentStreams uint32
        initialWindowSize    uint32
-       hbuf                 bytes.Buffer // HPACK encoder writes into this
-       henc                 *hpack.Encoder
-       freeBuf              [][]byte
+
+       hbuf    bytes.Buffer // HPACK encoder writes into this
+       henc    *hpack.Encoder
+       freeBuf [][]byte
 
        wmu  sync.Mutex // held while writing; acquire AFTER mu if holding both
        werr error      // first write error that has occurred
@@ -5117,7 +5188,7 @@ func http2shouldRetryRequest(req *Request, err error) bool {
        return err == http2errClientConnUnusable
 }
 
-func (t *http2Transport) dialClientConn(addr string) (*http2ClientConn, error) {
+func (t *http2Transport) dialClientConn(addr string, singleUse bool) (*http2ClientConn, error) {
        host, _, err := net.SplitHostPort(addr)
        if err != nil {
                return nil, err
@@ -5126,7 +5197,7 @@ func (t *http2Transport) dialClientConn(addr string) (*http2ClientConn, error) {
        if err != nil {
                return nil, err
        }
-       return t.NewClientConn(tconn)
+       return t.newClientConn(tconn, singleUse)
 }
 
 func (t *http2Transport) newTLSConfig(host string) *tls.Config {
@@ -5187,14 +5258,10 @@ func (t *http2Transport) expectContinueTimeout() time.Duration {
 }
 
 func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) {
-       if http2VerboseLogs {
-               t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr())
-       }
-       if _, err := c.Write(http2clientPreface); err != nil {
-               t.vlogf("client preface write error: %v", err)
-               return nil, err
-       }
+       return t.newClientConn(c, false)
+}
 
+func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2ClientConn, error) {
        cc := &http2ClientConn{
                t:                    t,
                tconn:                c,
@@ -5204,7 +5271,13 @@ func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) {
                initialWindowSize:    65535,
                maxConcurrentStreams: 1000,
                streams:              make(map[uint32]*http2clientStream),
+               singleUse:            singleUse,
+               wantSettingsAck:      true,
        }
+       if http2VerboseLogs {
+               t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
+       }
+
        cc.cond = sync.NewCond(&cc.mu)
        cc.flow.add(int32(http2initialWindowSize))
 
@@ -5228,6 +5301,8 @@ func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) {
        if max := t.maxHeaderListSize(); max != 0 {
                initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxHeaderListSize, Val: max})
        }
+
+       cc.bw.Write(http2clientPreface)
        cc.fr.WriteSettings(initialSettings...)
        cc.fr.WriteWindowUpdate(0, http2transportDefaultConnFlow)
        cc.inflow.add(http2transportDefaultConnFlow + http2initialWindowSize)
@@ -5236,32 +5311,6 @@ func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) {
                return nil, cc.werr
        }
 
-       f, err := cc.fr.ReadFrame()
-       if err != nil {
-               return nil, err
-       }
-       sf, ok := f.(*http2SettingsFrame)
-       if !ok {
-               return nil, fmt.Errorf("expected settings frame, got: %T", f)
-       }
-       cc.fr.WriteSettingsAck()
-       cc.bw.Flush()
-
-       sf.ForeachSetting(func(s http2Setting) error {
-               switch s.ID {
-               case http2SettingMaxFrameSize:
-                       cc.maxFrameSize = s.Val
-               case http2SettingMaxConcurrentStreams:
-                       cc.maxConcurrentStreams = s.Val
-               case http2SettingInitialWindowSize:
-                       cc.initialWindowSize = s.Val
-               default:
-
-                       t.vlogf("Unhandled Setting: %v", s)
-               }
-               return nil
-       })
-
        go cc.readLoop()
        return cc, nil
 }
@@ -5288,9 +5337,12 @@ func (cc *http2ClientConn) CanTakeNewRequest() bool {
 }
 
 func (cc *http2ClientConn) canTakeNewRequestLocked() bool {
+       if cc.singleUse && cc.nextStreamID > 1 {
+               return false
+       }
        return cc.goAway == nil && !cc.closed &&
                int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
-               cc.nextStreamID < 2147483647
+               cc.nextStreamID < math.MaxInt32
 }
 
 func (cc *http2ClientConn) closeIfIdle() {
@@ -5300,9 +5352,13 @@ func (cc *http2ClientConn) closeIfIdle() {
                return
        }
        cc.closed = true
+       nextID := cc.nextStreamID
 
        cc.mu.Unlock()
 
+       if http2VerboseLogs {
+               cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, nextID-2)
+       }
        cc.tconn.Close()
 }
 
@@ -5404,12 +5460,15 @@ func http2bodyAndLength(req *Request) (body io.Reader, contentLen int64) {
        // We have a body but a zero content length. Test to see if
        // it's actually zero or just unset.
        var buf [1]byte
-       n, rerr := io.ReadFull(body, buf[:])
+       n, rerr := body.Read(buf[:])
        if rerr != nil && rerr != io.EOF {
                return http2errorReader{rerr}, -1
        }
        if n == 1 {
 
+               if rerr == io.EOF {
+                       return bytes.NewReader(buf[:]), 1
+               }
                return io.MultiReader(bytes.NewReader(buf[:]), body), -1
        }
 
@@ -5494,9 +5553,10 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
        bodyWritten := false
        ctx := http2reqContext(req)
 
-       reFunc := func(re http2resAndError) (*Response, error) {
+       handleReadLoopResponse := func(re http2resAndError) (*Response, error) {
                res := re.res
                if re.err != nil || res.StatusCode > 299 {
+
                        bodyWriter.cancel()
                        cs.abortRequestBodyWrite(http2errStopReqBodyWrite)
                }
@@ -5512,7 +5572,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
        for {
                select {
                case re := <-readLoopResCh:
-                       return reFunc(re)
+                       return handleReadLoopResponse(re)
                case <-respHeaderTimer:
                        cc.forgetStreamID(cs.ID)
                        if !hasBody || bodyWritten {
@@ -5525,7 +5585,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
                case <-ctx.Done():
                        select {
                        case re := <-readLoopResCh:
-                               return reFunc(re)
+                               return handleReadLoopResponse(re)
                        default:
                        }
                        cc.forgetStreamID(cs.ID)
@@ -5539,7 +5599,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
                case <-req.Cancel:
                        select {
                        case re := <-readLoopResCh:
-                               return reFunc(re)
+                               return handleReadLoopResponse(re)
                        default:
                        }
                        cc.forgetStreamID(cs.ID)
@@ -5553,14 +5613,15 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
                case <-cs.peerReset:
                        select {
                        case re := <-readLoopResCh:
-                               return reFunc(re)
+                               return handleReadLoopResponse(re)
                        default:
                        }
                        return nil, cs.resetErr
                case err := <-bodyWriter.resc:
+
                        select {
                        case re := <-readLoopResCh:
-                               return reFunc(re)
+                               return handleReadLoopResponse(re)
                        default:
                        }
                        if err != nil {
@@ -5670,26 +5731,29 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos
                }
        }
 
+       if sentEnd {
+
+               return nil
+       }
+
+       var trls []byte
+       if hasTrailers {
+               cc.mu.Lock()
+               defer cc.mu.Unlock()
+               trls = cc.encodeTrailers(req)
+       }
+
        cc.wmu.Lock()
-       if !sentEnd {
-               var trls []byte
-               if hasTrailers {
-                       cc.mu.Lock()
-                       trls = cc.encodeTrailers(req)
-                       cc.mu.Unlock()
-               }
+       defer cc.wmu.Unlock()
 
-               if len(trls) > 0 {
-                       err = cc.writeHeaders(cs.ID, true, trls)
-               } else {
-                       err = cc.fr.WriteData(cs.ID, true, nil)
-               }
+       if len(trls) > 0 {
+               err = cc.writeHeaders(cs.ID, true, trls)
+       } else {
+               err = cc.fr.WriteData(cs.ID, true, nil)
        }
        if ferr := cc.bw.Flush(); ferr != nil && err == nil {
                err = ferr
        }
-       cc.wmu.Unlock()
-
        return err
 }
 
@@ -5918,6 +5982,14 @@ func (e http2GoAwayError) Error() string {
                e.LastStreamID, e.ErrCode, e.DebugData)
 }
 
+func http2isEOFOrNetReadError(err error) bool {
+       if err == io.EOF {
+               return true
+       }
+       ne, ok := err.(*net.OpError)
+       return ok && ne.Op == "read"
+}
+
 func (rl *http2clientConnReadLoop) cleanup() {
        cc := rl.cc
        defer cc.tconn.Close()
@@ -5926,16 +5998,14 @@ func (rl *http2clientConnReadLoop) cleanup() {
 
        err := cc.readerErr
        cc.mu.Lock()
-       if err == io.EOF {
-               if cc.goAway != nil {
-                       err = http2GoAwayError{
-                               LastStreamID: cc.goAway.LastStreamID,
-                               ErrCode:      cc.goAway.ErrCode,
-                               DebugData:    cc.goAwayDebug,
-                       }
-               } else {
-                       err = io.ErrUnexpectedEOF
+       if cc.goAway != nil && http2isEOFOrNetReadError(err) {
+               err = http2GoAwayError{
+                       LastStreamID: cc.goAway.LastStreamID,
+                       ErrCode:      cc.goAway.ErrCode,
+                       DebugData:    cc.goAwayDebug,
                }
+       } else if err == io.EOF {
+               err = io.ErrUnexpectedEOF
        }
        for _, cs := range rl.activeRes {
                cs.bufPipe.CloseWithError(err)
@@ -5954,16 +6024,21 @@ func (rl *http2clientConnReadLoop) cleanup() {
 
 func (rl *http2clientConnReadLoop) run() error {
        cc := rl.cc
-       rl.closeWhenIdle = cc.t.disableKeepAlives()
+       rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse
        gotReply := false
+       gotSettings := false
        for {
                f, err := cc.fr.ReadFrame()
                if err != nil {
-                       cc.vlogf("Transport readFrame error: (%T) %v", err, err)
+                       cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err)
                }
                if se, ok := err.(http2StreamError); ok {
                        if cs := cc.streamByID(se.StreamID, true); cs != nil {
-                               rl.endStreamError(cs, cc.fr.errDetail)
+                               cs.cc.writeStreamReset(cs.ID, se.Code, err)
+                               if se.Cause == nil {
+                                       se.Cause = cc.fr.errDetail
+                               }
+                               rl.endStreamError(cs, se)
                        }
                        continue
                } else if err != nil {
@@ -5972,6 +6047,13 @@ func (rl *http2clientConnReadLoop) run() error {
                if http2VerboseLogs {
                        cc.vlogf("http2: Transport received %s", http2summarizeFrame(f))
                }
+               if !gotSettings {
+                       if _, ok := f.(*http2SettingsFrame); !ok {
+                               cc.logf("protocol error: received %T before a SETTINGS frame", f)
+                               return http2ConnectionError(http2ErrCodeProtocol)
+                       }
+                       gotSettings = true
+               }
                maybeIdle := false
 
                switch f := f.(type) {
@@ -6000,6 +6082,9 @@ func (rl *http2clientConnReadLoop) run() error {
                        cc.logf("Transport: unhandled response frame type %T", f)
                }
                if err != nil {
+                       if http2VerboseLogs {
+                               cc.vlogf("http2: Transport conn %p received error from processing frame %v: %v", cc, http2summarizeFrame(f), err)
+                       }
                        return err
                }
                if rl.closeWhenIdle && gotReply && maybeIdle && len(rl.activeRes) == 0 {
@@ -6238,10 +6323,27 @@ var http2errClosedResponseBody = errors.New("http2: response body closed")
 
 func (b http2transportResponseBody) Close() error {
        cs := b.cs
-       if cs.bufPipe.Err() != io.EOF {
+       cc := cs.cc
 
-               cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
+       serverSentStreamEnd := cs.bufPipe.Err() == io.EOF
+       unread := cs.bufPipe.Len()
+
+       if unread > 0 || !serverSentStreamEnd {
+               cc.mu.Lock()
+               cc.wmu.Lock()
+               if !serverSentStreamEnd {
+                       cc.fr.WriteRSTStream(cs.ID, http2ErrCodeCancel)
+               }
+
+               if unread > 0 {
+                       cc.inflow.add(int32(unread))
+                       cc.fr.WriteWindowUpdate(0, uint32(unread))
+               }
+               cc.bw.Flush()
+               cc.wmu.Unlock()
+               cc.mu.Unlock()
        }
+
        cs.bufPipe.BreakWithError(http2errClosedResponseBody)
        return nil
 }
@@ -6249,6 +6351,7 @@ func (b http2transportResponseBody) Close() error {
 func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error {
        cc := rl.cc
        cs := cc.streamByID(f.StreamID, f.StreamEnded())
+       data := f.Data()
        if cs == nil {
                cc.mu.Lock()
                neverSent := cc.nextStreamID
@@ -6259,27 +6362,49 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error {
                        return http2ConnectionError(http2ErrCodeProtocol)
                }
 
+               if f.Length > 0 {
+                       cc.mu.Lock()
+                       cc.inflow.add(int32(f.Length))
+                       cc.mu.Unlock()
+
+                       cc.wmu.Lock()
+                       cc.fr.WriteWindowUpdate(0, uint32(f.Length))
+                       cc.bw.Flush()
+                       cc.wmu.Unlock()
+               }
                return nil
        }
-       if data := f.Data(); len(data) > 0 {
-               if cs.bufPipe.b == nil {
+       if f.Length > 0 {
+               if len(data) > 0 && cs.bufPipe.b == nil {
 
                        cc.logf("http2: Transport received DATA frame for closed stream; closing connection")
                        return http2ConnectionError(http2ErrCodeProtocol)
                }
 
                cc.mu.Lock()
-               if cs.inflow.available() >= int32(len(data)) {
-                       cs.inflow.take(int32(len(data)))
+               if cs.inflow.available() >= int32(f.Length) {
+                       cs.inflow.take(int32(f.Length))
                } else {
                        cc.mu.Unlock()
                        return http2ConnectionError(http2ErrCodeFlowControl)
                }
+
+               if pad := int32(f.Length) - int32(len(data)); pad > 0 {
+                       cs.inflow.add(pad)
+                       cc.inflow.add(pad)
+                       cc.wmu.Lock()
+                       cc.fr.WriteWindowUpdate(0, uint32(pad))
+                       cc.fr.WriteWindowUpdate(cs.ID, uint32(pad))
+                       cc.bw.Flush()
+                       cc.wmu.Unlock()
+               }
                cc.mu.Unlock()
 
-               if _, err := cs.bufPipe.Write(data); err != nil {
-                       rl.endStreamError(cs, err)
-                       return err
+               if len(data) > 0 {
+                       if _, err := cs.bufPipe.Write(data); err != nil {
+                               rl.endStreamError(cs, err)
+                               return err
+                       }
                }
        }
 
@@ -6304,9 +6429,14 @@ func (rl *http2clientConnReadLoop) endStreamError(cs *http2clientStream, err err
        }
        cs.bufPipe.closeWithErrorAndCode(err, code)
        delete(rl.activeRes, cs.ID)
-       if cs.req.Close || cs.req.Header.Get("Connection") == "close" {
+       if http2isConnectionCloseRequest(cs.req) {
                rl.closeWhenIdle = true
        }
+
+       select {
+       case cs.resc <- http2resAndError{err: err}:
+       default:
+       }
 }
 
 func (cs *http2clientStream) copyTrailers() {
@@ -6334,7 +6464,16 @@ func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error
        cc := rl.cc
        cc.mu.Lock()
        defer cc.mu.Unlock()
-       return f.ForeachSetting(func(s http2Setting) error {
+
+       if f.IsAck() {
+               if cc.wantSettingsAck {
+                       cc.wantSettingsAck = false
+                       return nil
+               }
+               return http2ConnectionError(http2ErrCodeProtocol)
+       }
+
+       err := f.ForeachSetting(func(s http2Setting) error {
                switch s.ID {
                case http2SettingMaxFrameSize:
                        cc.maxFrameSize = s.Val
@@ -6342,6 +6481,16 @@ func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error
                        cc.maxConcurrentStreams = s.Val
                case http2SettingInitialWindowSize:
 
+                       if s.Val > math.MaxInt32 {
+                               return http2ConnectionError(http2ErrCodeFlowControl)
+                       }
+
+                       delta := int32(s.Val) - int32(cc.initialWindowSize)
+                       for _, cs := range cc.streams {
+                               cs.flow.add(delta)
+                       }
+                       cc.cond.Broadcast()
+
                        cc.initialWindowSize = s.Val
                default:
 
@@ -6349,6 +6498,16 @@ func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error
                }
                return nil
        })
+       if err != nil {
+               return err
+       }
+
+       cc.wmu.Lock()
+       defer cc.wmu.Unlock()
+
+       cc.fr.WriteSettingsAck()
+       cc.bw.Flush()
+       return cc.werr
 }
 
 func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame) error {
@@ -6382,7 +6541,7 @@ func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) er
        case <-cs.peerReset:
 
        default:
-               err := http2StreamError{cs.ID, f.ErrCode}
+               err := http2streamError(cs.ID, f.ErrCode)
                cs.resetErr = err
                close(cs.peerReset)
                cs.bufPipe.CloseWithError(err)
@@ -6560,6 +6719,12 @@ func (s http2bodyWriterState) scheduleBodyWrite() {
        }
 }
 
+// isConnectionCloseRequest reports whether req should use its own
+// connection for a single request and then close the connection.
+func http2isConnectionCloseRequest(req *Request) bool {
+       return req.Close || httplex.HeaderValuesContainsToken(req.Header["Connection"], "close")
+}
+
 // writeFramer is implemented by any type that is used to write frames.
 type http2writeFramer interface {
        writeFrame(http2writeContext) error
index 139ce3eafc7e1ff4a5de2d3ea35afe296a9dd015..13e5f283e4c4d6dbc8285e70ed64885583aed927 100644 (file)
@@ -4716,3 +4716,14 @@ func BenchmarkCloseNotifier(b *testing.B) {
        }
        b.StopTimer()
 }
+
+// Verify this doesn't race (Issue 16505)
+func TestConcurrentServerServe(t *testing.T) {
+       for i := 0; i < 100; i++ {
+               ln1 := &oneConnListener{conn: nil}
+               ln2 := &oneConnListener{conn: nil}
+               srv := Server{}
+               go func() { srv.Serve(ln1) }()
+               go func() { srv.Serve(ln2) }()
+       }
+}
index 7b2b4b2f42307245e887f9e3a82f390c6b3e69d7..89574a8b36e7e9300df967816621f2b59b35da0f 100644 (file)
@@ -2129,8 +2129,8 @@ type Server struct {
        ErrorLog *log.Logger
 
        disableKeepAlives int32     // accessed atomically.
-       nextProtoOnce     sync.Once // guards initialization of TLSNextProto in Serve
-       nextProtoErr      error
+       nextProtoOnce     sync.Once // guards setupHTTP2_* init
+       nextProtoErr      error     // result of http2.ConfigureServer if used
 }
 
 // A ConnState represents the state of a client connection to a server.
@@ -2260,10 +2260,8 @@ func (srv *Server) Serve(l net.Listener) error {
        }
        var tempDelay time.Duration // how long to sleep on accept failure
 
-       if srv.shouldConfigureHTTP2ForServe() {
-               if err := srv.setupHTTP2(); err != nil {
-                       return err
-               }
+       if err := srv.setupHTTP2_Serve(); err != nil {
+               return err
        }
 
        // TODO: allow changing base context? can't imagine concrete
@@ -2408,7 +2406,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
 
        // Setup HTTP/2 before srv.Serve, to initialize srv.TLSConfig
        // before we clone it and create the TLS Listener.
-       if err := srv.setupHTTP2(); err != nil {
+       if err := srv.setupHTTP2_ListenAndServeTLS(); err != nil {
                return err
        }
 
@@ -2436,14 +2434,36 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
        return srv.Serve(tlsListener)
 }
 
-func (srv *Server) setupHTTP2() error {
+// setupHTTP2_ListenAndServeTLS conditionally configures HTTP/2 on
+// srv and returns whether there was an error setting it up. If it is
+// not configured for policy reasons, nil is returned.
+func (srv *Server) setupHTTP2_ListenAndServeTLS() error {
        srv.nextProtoOnce.Do(srv.onceSetNextProtoDefaults)
        return srv.nextProtoErr
 }
 
+// setupHTTP2_Serve is called from (*Server).Serve and conditionally
+// configures HTTP/2 on srv using a more conservative policy than
+// setupHTTP2_ListenAndServeTLS because Serve may be called
+// concurrently.
+//
+// The tests named TestTransportAutomaticHTTP2* and
+// TestConcurrentServerServe in server_test.go demonstrate some
+// of the supported use cases and motivations.
+func (srv *Server) setupHTTP2_Serve() error {
+       srv.nextProtoOnce.Do(srv.onceSetNextProtoDefaults_Serve)
+       return srv.nextProtoErr
+}
+
+func (srv *Server) onceSetNextProtoDefaults_Serve() {
+       if srv.shouldConfigureHTTP2ForServe() {
+               srv.onceSetNextProtoDefaults()
+       }
+}
+
 // onceSetNextProtoDefaults configures HTTP/2, if the user hasn't
 // configured otherwise. (by setting srv.TLSNextProto non-nil)
-// It must only be called via srv.nextProtoOnce (use srv.setupHTTP2).
+// It must only be called via srv.nextProtoOnce (use srv.setupHTTP2_*).
 func (srv *Server) onceSetNextProtoDefaults() {
        if strings.Contains(os.Getenv("GODEBUG"), "http2server=0") {
                return
index 9164d0d827c0017af81673c450f009622681e3e3..1f0763471b8e2cfed689101ea40c40cd0d93070b 100644 (file)
@@ -383,6 +383,11 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
                        return resp, nil
                }
                if !pconn.shouldRetryRequest(req, err) {
+                       // Issue 16465: return underlying net.Conn.Read error from peek,
+                       // as we've historically done.
+                       if e, ok := err.(transportReadFromServerError); ok {
+                               err = e.err
+                       }
                        return nil, err
                }
                testHookRoundTripRetried()
@@ -393,6 +398,15 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
 // HTTP request on a new connection. The non-nil input error is the
 // error from roundTrip.
 func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool {
+       if err == http2ErrNoCachedConn {
+               // Issue 16582: if the user started a bunch of
+               // requests at once, they can all pick the same conn
+               // and violate the server's max concurrent streams.
+               // Instead, match the HTTP/1 behavior for now and dial
+               // again to get a new TCP connection, rather than failing
+               // this request.
+               return true
+       }
        if err == errMissingHost {
                // User error.
                return false
@@ -415,11 +429,19 @@ func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool {
                // first, per golang.org/issue/15723
                return false
        }
-       if _, ok := err.(nothingWrittenError); ok {
+       switch err.(type) {
+       case nothingWrittenError:
                // We never wrote anything, so it's safe to retry.
                return true
+       case transportReadFromServerError:
+               // We got some non-EOF net.Conn.Read failure reading
+               // the 1st response byte from the server.
+               return true
        }
-       if err == errServerClosedIdle || err == errServerClosedConn {
+       if err == errServerClosedIdle {
+               // The server replied with io.EOF while we were trying to
+               // read the response. Probably an unfortunately keep-alive
+               // timeout, just as the client was writing a request.
                return true
        }
        return false // conservatively
@@ -476,8 +498,9 @@ func (t *Transport) CloseIdleConnections() {
 // CancelRequest cancels an in-flight request by closing its connection.
 // CancelRequest should only be called after RoundTrip has returned.
 //
-// Deprecated: Use Request.Cancel instead. CancelRequest cannot cancel
-// HTTP/2 requests.
+// Deprecated: Use Request.WithContext to create a request with a
+// cancelable context instead. CancelRequest cannot cancel HTTP/2
+// requests.
 func (t *Transport) CancelRequest(req *Request) {
        t.reqMu.Lock()
        cancel := t.reqCanceler[req]
@@ -566,10 +589,26 @@ var (
        errCloseIdleConns     = errors.New("http: CloseIdleConnections called")
        errReadLoopExiting    = errors.New("http: persistConn.readLoop exiting")
        errServerClosedIdle   = errors.New("http: server closed idle connection")
-       errServerClosedConn   = errors.New("http: server closed connection")
        errIdleConnTimeout    = errors.New("http: idle connection timeout")
+       errNotCachingH2Conn   = errors.New("http: not caching alternate protocol's connections")
 )
 
+// transportReadFromServerError is used by Transport.readLoop when the
+// 1 byte peek read fails and we're actually anticipating a response.
+// Usually this is just due to the inherent keep-alive shut down race,
+// where the server closed the connection at the same time the client
+// wrote. The underlying err field is usually io.EOF or some
+// ECONNRESET sort of thing which varies by platform. But it might be
+// the user's custom net.Conn.Read error too, so we carry it along for
+// them to return from Transport.RoundTrip.
+type transportReadFromServerError struct {
+       err error
+}
+
+func (e transportReadFromServerError) Error() string {
+       return fmt.Sprintf("net/http: Transport failed to read from server: %v", e.err)
+}
+
 func (t *Transport) putOrCloseIdleConn(pconn *persistConn) {
        if err := t.tryPutIdleConn(pconn); err != nil {
                pconn.close(err)
@@ -595,6 +634,9 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error {
        if pconn.isBroken() {
                return errConnBroken
        }
+       if pconn.alt != nil {
+               return errNotCachingH2Conn
+       }
        pconn.markReused()
        key := pconn.cacheKey
 
@@ -1293,7 +1335,10 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er
        if pc.isCanceled() {
                return errRequestCanceled
        }
-       if err == errServerClosedIdle || err == errServerClosedConn {
+       if err == errServerClosedIdle {
+               return err
+       }
+       if _, ok := err.(transportReadFromServerError); ok {
                return err
        }
        if pc.isBroken() {
@@ -1314,7 +1359,11 @@ func (pc *persistConn) mapRoundTripErrorAfterClosed(startBytesWritten int64) err
                return errRequestCanceled
        }
        err := pc.closed
-       if err == errServerClosedIdle || err == errServerClosedConn {
+       if err == errServerClosedIdle {
+               // Don't decorate
+               return err
+       }
+       if _, ok := err.(transportReadFromServerError); ok {
                // Don't decorate
                return err
        }
@@ -1383,7 +1432,7 @@ func (pc *persistConn) readLoop() {
                if err == nil {
                        resp, err = pc.readResponse(rc, trace)
                } else {
-                       err = errServerClosedConn
+                       err = transportReadFromServerError{err}
                        closeErr = err
                }
 
@@ -1784,6 +1833,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
        var re responseAndError
        var respHeaderTimer <-chan time.Time
        cancelChan := req.Request.Cancel
+       ctxDoneChan := req.Context().Done()
 WaitResponse:
        for {
                testHookWaitResLoop()
@@ -1815,9 +1865,11 @@ WaitResponse:
                case <-cancelChan:
                        pc.t.CancelRequest(req.Request)
                        cancelChan = nil
-               case <-req.Context().Done():
+                       ctxDoneChan = nil
+               case <-ctxDoneChan:
                        pc.t.CancelRequest(req.Request)
                        cancelChan = nil
+                       ctxDoneChan = nil
                }
        }
 
index a157d906300a2cd87d4ca7a0788cd0f92bf35cbd..a05ca6ed0d869adc29133824cd48be0fb87770e4 100644 (file)
@@ -46,17 +46,22 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) {
        conn.Close() // simulate the server hanging up on the client
 
        _, err = pc.roundTrip(treq)
-       if err != errServerClosedConn && err != errServerClosedIdle {
+       if !isTransportReadFromServerError(err) && err != errServerClosedIdle {
                t.Fatalf("roundTrip = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err)
        }
 
        <-pc.closech
        err = pc.closed
-       if err != errServerClosedConn && err != errServerClosedIdle {
+       if !isTransportReadFromServerError(err) && err != errServerClosedIdle {
                t.Fatalf("pc.closed = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err)
        }
 }
 
+func isTransportReadFromServerError(err error) bool {
+       _, ok := err.(transportReadFromServerError)
+       return ok
+}
+
 func newLocalListener(t *testing.T) net.Listener {
        ln, err := net.Listen("tcp", "127.0.0.1:0")
        if err != nil {
index 72b98f16d7eaa0ca9b726034c807e5fd0daa598a..298682d04de93a4f727063d5dc20ad509dd586fd 100644 (file)
@@ -3511,6 +3511,100 @@ func TestTransportIdleConnTimeout(t *testing.T) {
        }
 }
 
+// Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an
+// HTTP/2 connection was established but but its caller no longer
+// wanted it. (Assuming the connection cache was enabled, which it is
+// by default)
+//
+// This test reproduced the crash by setting the IdleConnTimeout low
+// (to make the test reasonable) and then making a request which is
+// canceled by the DialTLS hook, which then also waits to return the
+// real connection until after the RoundTrip saw the error.  Then we
+// know the successful tls.Dial from DialTLS will need to go into the
+// idle pool. Then we give it a of time to explode.
+func TestIdleConnH2Crash(t *testing.T) {
+       cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               // nothing
+       }))
+       defer cst.close()
+
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+
+       gotErr := make(chan bool, 1)
+
+       cst.tr.IdleConnTimeout = 5 * time.Millisecond
+       cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
+               cancel()
+               <-gotErr
+               c, err := tls.Dial(network, addr, &tls.Config{
+                       InsecureSkipVerify: true,
+                       NextProtos:         []string{"h2"},
+               })
+               if err != nil {
+                       t.Error(err)
+                       return nil, err
+               }
+               if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
+                       t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
+                       c.Close()
+                       return nil, errors.New("bogus")
+               }
+               return c, nil
+       }
+
+       req, _ := NewRequest("GET", cst.ts.URL, nil)
+       req = req.WithContext(ctx)
+       res, err := cst.c.Do(req)
+       if err == nil {
+               res.Body.Close()
+               t.Fatal("unexpected success")
+       }
+       gotErr <- true
+
+       // Wait for the explosion.
+       time.Sleep(cst.tr.IdleConnTimeout * 10)
+}
+
+type funcConn struct {
+       net.Conn
+       read  func([]byte) (int, error)
+       write func([]byte) (int, error)
+}
+
+func (c funcConn) Read(p []byte) (int, error)  { return c.read(p) }
+func (c funcConn) Write(p []byte) (int, error) { return c.write(p) }
+func (c funcConn) Close() error                { return nil }
+
+// Issue 16465: Transport.RoundTrip should return the raw net.Conn.Read error from Peek
+// back to the caller.
+func TestTransportReturnsPeekError(t *testing.T) {
+       errValue := errors.New("specific error value")
+
+       wrote := make(chan struct{})
+       var wroteOnce sync.Once
+
+       tr := &Transport{
+               Dial: func(network, addr string) (net.Conn, error) {
+                       c := funcConn{
+                               read: func([]byte) (int, error) {
+                                       <-wrote
+                                       return 0, errValue
+                               },
+                               write: func(p []byte) (int, error) {
+                                       wroteOnce.Do(func() { close(wrote) })
+                                       return len(p), nil
+                               },
+                       }
+                       return c, nil
+               },
+       }
+       _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
+       if err != errValue {
+               t.Errorf("error = %#v; want %v", err, errValue)
+       }
+}
+
 var errFakeRoundTrip = errors.New("fake roundtrip")
 
 type funcRoundTripper func()
index 5dbd7f97662989d8b977277da4bc92daf80aa4e2..74b7494c0de8c2be91275b0debf20faa47856763 100644 (file)
@@ -28,6 +28,12 @@ func (p *Process) blockUntilWaitable() (bool, error) {
        _, _, e := syscall.Syscall6(syscall.SYS_WAITID, _P_PID, uintptr(p.Pid), uintptr(unsafe.Pointer(psig)), syscall.WEXITED|syscall.WNOWAIT, 0, 0)
        runtime.KeepAlive(psig)
        if e != 0 {
+               // waitid has been available since Linux 2.6.9, but
+               // reportedly is not available in Ubuntu on Windows.
+               // See issue 16610.
+               if e == syscall.ENOSYS {
+                       return false, nil
+               }
                return false, NewSyscallError("waitid", e)
        }
        return true, nil
index 8ca007f70aac65173af440573861a1a4fe485278..a7e2e6422bc34fdcb33807998293f2aa734a993f 100644 (file)
@@ -4,4 +4,7 @@
 
 package filepath
 
-var ToNorm = toNorm
+var (
+       ToNorm   = toNorm
+       NormBase = normBase
+)
index 4d5e3bdcb6d620d758217ee023352ea0c69fffd1..9c3f287ecbd90e88ca784ac62bf9993e959569b7 100644 (file)
@@ -843,7 +843,7 @@ func TestEvalSymlinks(t *testing.T) {
                if p, err := filepath.EvalSymlinks(path); err != nil {
                        t.Errorf("EvalSymlinks(%q) error: %v", d.path, err)
                } else if filepath.Clean(p) != filepath.Clean(dest) {
-                       t.Errorf("Clean(%q)=%q, want %q", path, p, dest)
+                       t.Errorf("EvalSymlinks(%q)=%q, want %q", path, p, dest)
                }
 
                // test EvalSymlinks(".")
@@ -875,6 +875,34 @@ func TestEvalSymlinks(t *testing.T) {
                        t.Errorf(`EvalSymlinks(".") in %q directory returns %q, want "." or %q`, d.path, p, want)
                }()
 
+               // test EvalSymlinks(".."+path)
+               func() {
+                       defer func() {
+                               err := os.Chdir(wd)
+                               if err != nil {
+                                       t.Fatal(err)
+                               }
+                       }()
+
+                       err := os.Chdir(simpleJoin(tmpDir, "test"))
+                       if err != nil {
+                               t.Error(err)
+                               return
+                       }
+
+                       path := simpleJoin("..", d.path)
+                       dest := simpleJoin("..", d.dest)
+                       if filepath.IsAbs(d.dest) || os.IsPathSeparator(d.dest[0]) {
+                               dest = d.dest
+                       }
+
+                       if p, err := filepath.EvalSymlinks(path); err != nil {
+                               t.Errorf("EvalSymlinks(%q) error: %v", d.path, err)
+                       } else if filepath.Clean(p) != filepath.Clean(dest) {
+                               t.Errorf("EvalSymlinks(%q)=%q, want %q", path, p, dest)
+                       }
+               }()
+
                // test EvalSymlinks where parameter is relative path
                func() {
                        defer func() {
@@ -892,7 +920,7 @@ func TestEvalSymlinks(t *testing.T) {
                        if p, err := filepath.EvalSymlinks(d.path); err != nil {
                                t.Errorf("EvalSymlinks(%q) error: %v", d.path, err)
                        } else if filepath.Clean(p) != filepath.Clean(d.dest) {
-                               t.Errorf("Clean(%q)=%q, want %q", d.path, p, d.dest)
+                               t.Errorf("EvalSymlinks(%q)=%q, want %q", d.path, p, d.dest)
                        }
                }()
        }
index 243352819e0d23d66cb4b80d9cd94c2fcd3e320d..bb05aabc924292e4d084700c11710300b831c9ef 100644 (file)
@@ -22,7 +22,7 @@ func normVolumeName(path string) string {
        return strings.ToUpper(volume)
 }
 
-// normBase retruns the last element of path.
+// normBase returns the last element of path with correct case.
 func normBase(path string) (string, error) {
        p, err := syscall.UTF16PtrFromString(path)
        if err != nil {
@@ -40,7 +40,24 @@ func normBase(path string) (string, error) {
        return syscall.UTF16ToString(data.FileName[:]), nil
 }
 
-func toNorm(path string, base func(string) (string, error)) (string, error) {
+// baseIsDotDot returns whether the last element of path is "..".
+// The given path should be 'Clean'-ed in advance.
+func baseIsDotDot(path string) bool {
+       i := strings.LastIndexByte(path, Separator)
+       return path[i+1:] == ".."
+}
+
+// toNorm returns the normalized path that is guranteed to be unique.
+// It should accept the following formats:
+//   * UNC paths                              (e.g \\server\share\foo\bar)
+//   * absolute paths                         (e.g C:\foo\bar)
+//   * relative paths begin with drive letter (e.g C:foo\bar, C:..\foo\bar, C:.., C:.)
+//   * relative paths begin with '\'          (e.g \foo\bar)
+//   * relative paths begin without '\'       (e.g foo\bar, ..\foo\bar, .., .)
+// The returned normalized path will be in the same form (of 5 listed above) as the input path.
+// If two paths A and B are indicating the same file with the same format, toNorm(A) should be equal to toNorm(B).
+// The normBase parameter should be equal to the normBase func, except for in tests.  See docs on the normBase func.
+func toNorm(path string, normBase func(string) (string, error)) (string, error) {
        if path == "" {
                return path, nil
        }
@@ -58,7 +75,13 @@ func toNorm(path string, base func(string) (string, error)) (string, error) {
        var normPath string
 
        for {
-               name, err := base(volume + path)
+               if baseIsDotDot(path) {
+                       normPath = path + `\` + normPath
+
+                       break
+               }
+
+               name, err := normBase(volume + path)
                if err != nil {
                        return "", err
                }
index 7045e448ea386212b5f912c6d649ada9d859d6c4..e601c2cbde20894b6aa81aab393fe3526e96bece 100644 (file)
@@ -5752,6 +5752,8 @@ func TestTypeStrings(t *testing.T) {
                {TypeOf(new(XM)), "*reflect_test.XM"},
                {TypeOf(new(XM).String), "func() string"},
                {TypeOf(new(XM)).Method(0).Type, "func(*reflect_test.XM) string"},
+               {ChanOf(3, TypeOf(XM{})), "chan reflect_test.XM"},
+               {MapOf(TypeOf(int(0)), TypeOf(XM{})), "map[int]reflect_test.XM"},
        }
 
        for i, test := range stringTests {
index d2e5b63fedbc11d65c303ab8aa48b2b34fc8305f..0a58bafd85f0fd9fdb0cf6e820596e9fa6b90ae9 100644 (file)
@@ -4,8 +4,69 @@
 
 // Package pprof writes runtime profiling data in the format expected
 // by the pprof visualization tool.
+//
+// Profiling a Go program
+//
+// The first step to profiling a Go program is to enable profiling.
+// Support for profiling benchmarks built with the standard testing
+// package is built into go test. For example, the following command
+// runs benchmarks in the current directory and writes the CPU and
+// memory profiles to cpu.prof and mem.prof:
+//
+//     go test -cpuprofile cpu.prof -memprofile mem.prof -bench .
+//
+// To add equivalent profiling support to a standalone program, add
+// code like the following to your main function:
+//
+//    var cpuprofile = flag.String("cpuprofile", "", "write cpu profile `file`")
+//    var memprofile = flag.String("memprofile", "", "write memory profile to `file`")
+//
+//    func main() {
+//        flag.Parse()
+//        if *cpuprofile != "" {
+//            f, err := os.Create(*cpuprofile)
+//            if err != nil {
+//                log.Fatal("could not create CPU profile: ", err)
+//            }
+//            if err := pprof.StartCPUProfile(f); err != nil {
+//                log.Fatal("could not start CPU profile: ", err)
+//            }
+//            defer pprof.StopCPUProfile()
+//        }
+//        ...
+//        if *memprofile != "" {
+//            f, err := os.Create(*memprofile)
+//            if err != nil {
+//                log.Fatal("could not create memory profile: ", err)
+//            }
+//            runtime.GC() // get up-to-date statistics
+//            if err := pprof.WriteHeapProfile(f); err != nil {
+//                log.Fatal("could not write memory profile: ", err)
+//            }
+//            f.Close()
+//        }
+//    }
+//
+// There is also a standard HTTP interface to profiling data. Adding
+// the following line will install handlers under the /debug/pprof/
+// URL to download live profiles:
+//
+//    import _ "net/http/pprof"
+//
+// See the net/http/pprof package for more details.
+//
+// Profiles can then be visualized with the pprof tool:
+//
+//    go tool pprof cpu.prof
+//
+// There are many commands available from the pprof command line.
+// Commonly used commands include "top", which prints a summary of the
+// top program hot-spots, and "web", which opens an interactive graph
+// of hot-spots and their call graphs. Use "help" for information on
+// all pprof commands.
+//
 // For more information about pprof, see
-// http://github.com/google/pprof/.
+// https://github.com/google/pprof/blob/master/doc/pprof.md.
 package pprof
 
 import (
diff --git a/libgo/go/syscall/syscall_darwin_test.go b/libgo/go/syscall/syscall_darwin_test.go
new file mode 100644 (file)
index 0000000..cea5636
--- /dev/null
@@ -0,0 +1,23 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin
+// +build amd64 386 arm arm64
+
+package syscall_test
+
+import (
+       "syscall"
+       "testing"
+)
+
+func TestDarwinGettimeofday(t *testing.T) {
+       tv := &syscall.Timeval{}
+       if err := syscall.Gettimeofday(tv); err != nil {
+               t.Fatal(err)
+       }
+       if tv.Sec == 0 && tv.Usec == 0 {
+               t.Fatal("Sec and Usec both zero")
+       }
+}