Update to current version of Go library.
authorIan Lance Taylor <ian@gcc.gnu.org>
Thu, 24 Mar 2011 23:46:17 +0000 (23:46 +0000)
committerIan Lance Taylor <ian@gcc.gnu.org>
Thu, 24 Mar 2011 23:46:17 +0000 (23:46 +0000)
From-SVN: r171427

172 files changed:
libgo/MERGE
libgo/Makefile.am
libgo/Makefile.in
libgo/go/archive/zip/reader.go
libgo/go/big/int.go
libgo/go/big/int_test.go
libgo/go/big/nat.go
libgo/go/bufio/bufio_test.go
libgo/go/compress/flate/deflate_test.go
libgo/go/compress/lzw/reader_test.go
libgo/go/compress/lzw/writer_test.go
libgo/go/crypto/ecdsa/ecdsa.go [new file with mode: 0644]
libgo/go/crypto/ecdsa/ecdsa_test.go [new file with mode: 0644]
libgo/go/crypto/elliptic/elliptic.go
libgo/go/crypto/openpgp/packet/packet.go
libgo/go/crypto/openpgp/packet/packet_test.go
libgo/go/crypto/openpgp/packet/private_key.go
libgo/go/crypto/openpgp/packet/public_key.go
libgo/go/crypto/openpgp/packet/public_key_test.go
libgo/go/crypto/openpgp/packet/signature.go
libgo/go/crypto/openpgp/read_test.go
libgo/go/crypto/openpgp/write.go
libgo/go/crypto/openpgp/write_test.go
libgo/go/crypto/tls/common.go
libgo/go/crypto/tls/conn.go
libgo/go/crypto/tls/generate_cert.go
libgo/go/debug/proc/proc_darwin.go
libgo/go/debug/proc/proc_freebsd.go
libgo/go/debug/proc/proc_linux.go
libgo/go/debug/proc/proc_windows.go
libgo/go/exec/exec.go
libgo/go/exec/exec_test.go
libgo/go/exp/eval/stmt.go
libgo/go/exp/eval/stmt_test.go
libgo/go/exp/ogle/cmd.go
libgo/go/expvar/expvar.go
libgo/go/flag/flag.go
libgo/go/flag/flag_test.go
libgo/go/fmt/format.go
libgo/go/fmt/scan.go
libgo/go/fmt/scan_test.go
libgo/go/go/ast/ast.go
libgo/go/go/ast/filter.go
libgo/go/go/ast/print.go
libgo/go/go/ast/scope.go
libgo/go/go/ast/walk.go
libgo/go/go/parser/interface.go
libgo/go/go/parser/parser.go
libgo/go/go/parser/parser_test.go
libgo/go/go/printer/nodes.go
libgo/go/go/printer/printer.go
libgo/go/go/printer/printer_test.go
libgo/go/go/printer/testdata/expressions.golden
libgo/go/go/printer/testdata/expressions.input
libgo/go/go/printer/testdata/expressions.raw
libgo/go/go/printer/testdata/slow.golden [new file with mode: 0644]
libgo/go/go/printer/testdata/slow.input [new file with mode: 0644]
libgo/go/go/scanner/scanner.go
libgo/go/go/scanner/scanner_test.go
libgo/go/go/typechecker/scope.go
libgo/go/go/typechecker/testdata/test0.go [deleted file]
libgo/go/go/typechecker/testdata/test0.src [new file with mode: 0644]
libgo/go/go/typechecker/testdata/test1.go [deleted file]
libgo/go/go/typechecker/testdata/test1.src [new file with mode: 0644]
libgo/go/go/typechecker/testdata/test3.go [deleted file]
libgo/go/go/typechecker/testdata/test3.src [new file with mode: 0644]
libgo/go/go/typechecker/testdata/test4.go [deleted file]
libgo/go/go/typechecker/testdata/test4.src [new file with mode: 0644]
libgo/go/go/typechecker/type.go [new file with mode: 0644]
libgo/go/go/typechecker/typechecker.go
libgo/go/go/typechecker/typechecker_test.go
libgo/go/go/typechecker/universe.go
libgo/go/gob/codec_test.go
libgo/go/gob/decode.go
libgo/go/gob/decoder.go
libgo/go/gob/encode.go
libgo/go/gob/encoder.go
libgo/go/gob/gobencdec_test.go [new file with mode: 0644]
libgo/go/gob/timing_test.go [new file with mode: 0644]
libgo/go/gob/type.go
libgo/go/gob/type_test.go
libgo/go/hash/fnv/fnv.go [new file with mode: 0644]
libgo/go/hash/fnv/fnv_test.go [new file with mode: 0644]
libgo/go/http/cgi/child.go [new file with mode: 0644]
libgo/go/http/cgi/child_test.go [new file with mode: 0644]
libgo/go/http/cgi/host.go [new file with mode: 0644]
libgo/go/http/cgi/host_test.go [new file with mode: 0644]
libgo/go/http/cgi/matryoshka_test.go [new file with mode: 0644]
libgo/go/http/client.go
libgo/go/http/client_test.go
libgo/go/http/cookie.go [new file with mode: 0644]
libgo/go/http/cookie_test.go [new file with mode: 0644]
libgo/go/http/dump.go
libgo/go/http/export_test.go [new file with mode: 0644]
libgo/go/http/fs.go
libgo/go/http/fs_test.go
libgo/go/http/httptest/recorder.go [new file with mode: 0644]
libgo/go/http/httptest/server.go [new file with mode: 0644]
libgo/go/http/persist.go
libgo/go/http/pprof/pprof.go
libgo/go/http/proxy_test.go
libgo/go/http/range_test.go [new file with mode: 0644]
libgo/go/http/readrequest_test.go
libgo/go/http/request.go
libgo/go/http/request_test.go
libgo/go/http/requestwrite_test.go
libgo/go/http/response.go
libgo/go/http/responsewrite_test.go
libgo/go/http/serve_test.go
libgo/go/http/server.go
libgo/go/http/transport.go
libgo/go/http/transport_test.go [new file with mode: 0644]
libgo/go/io/ioutil/ioutil.go
libgo/go/io/ioutil/tempfile.go
libgo/go/io/ioutil/tempfile_test.go
libgo/go/io/pipe.go
libgo/go/mime/multipart/multipart.go
libgo/go/mime/multipart/multipart_test.go
libgo/go/net/fd.go
libgo/go/net/fd_linux.go
libgo/go/net/ip.go
libgo/go/netchan/common.go
libgo/go/netchan/export.go
libgo/go/netchan/import.go
libgo/go/netchan/netchan_test.go
libgo/go/os/exec.go
libgo/go/os/inotify/inotify_linux_test.go
libgo/go/os/os_test.go
libgo/go/path/filepath/match.go [new file with mode: 0644]
libgo/go/path/filepath/match_test.go [new file with mode: 0644]
libgo/go/path/filepath/path.go [new file with mode: 0644]
libgo/go/path/filepath/path_test.go [new file with mode: 0644]
libgo/go/path/filepath/path_unix.go [new file with mode: 0644]
libgo/go/path/filepath/path_windows.go [new file with mode: 0644]
libgo/go/path/match.go
libgo/go/path/match_test.go
libgo/go/path/path.go
libgo/go/path/path_test.go
libgo/go/path/path_unix.go [deleted file]
libgo/go/path/path_windows.go [deleted file]
libgo/go/reflect/all_test.go
libgo/go/reflect/value.go
libgo/go/rpc/client.go
libgo/go/rpc/server.go
libgo/go/rpc/server_test.go
libgo/go/runtime/debug.go
libgo/go/runtime/mem.go [new file with mode: 0644]
libgo/go/runtime/pprof/pprof.go
libgo/go/runtime/pprof/pprof_test.go [new file with mode: 0644]
libgo/go/strings/strings.go
libgo/go/strings/strings_test.go
libgo/go/sync/waitgroup.go
libgo/go/template/template.go
libgo/go/testing/script/script.go
libgo/go/testing/testing.go
libgo/go/time/sleep.go
libgo/go/time/sleep_test.go
libgo/go/time/sys.go [new file with mode: 0644]
libgo/go/time/time.go
libgo/go/time/time_test.go
libgo/go/websocket/server.go
libgo/go/websocket/websocket_test.go
libgo/go/xml/xml.go
libgo/merge.sh
libgo/runtime/channel.h
libgo/runtime/go-rec-big.c
libgo/runtime/go-rec-nb-big.c
libgo/runtime/go-rec-nb-small.c
libgo/runtime/go-rec-small.c
libgo/runtime/go-reflect-chan.c
libgo/syscalls/exec.go
libgo/testsuite/gotest

index e572b232450bea48d7a3b27a1292a3d9490bec39..729be06029097a2dbe0418ae1a25b060064ce3e1 100644 (file)
@@ -1,4 +1,4 @@
-94d654be2064
+31d7feb9281b
 
 The first line of this file holds the Mercurial revision number of the
 last merge done from the master library sources.
index 2db29fdea527ecabbb3bae0ee9e88580d753a28d..77bf27a7c53af9a48559614b4a7fad3d7020d650 100644 (file)
@@ -182,6 +182,7 @@ toolexeclibgocrypto_DATA = \
        crypto/cast5.gox \
        crypto/cipher.gox \
        crypto/dsa.gox \
+       crypto/ecdsa.gox \
        crypto/elliptic.gox \
        crypto/hmac.gox \
        crypto/md4.gox \
@@ -254,11 +255,14 @@ toolexeclibgohashdir = $(toolexeclibgodir)/hash
 toolexeclibgohash_DATA = \
        hash/adler32.gox \
        hash/crc32.gox \
-       hash/crc64.gox
+       hash/crc64.gox \
+       hash/fnv.gox
 
 toolexeclibgohttpdir = $(toolexeclibgodir)/http
 
 toolexeclibgohttp_DATA = \
+       http/cgi.gox \
+       http/httptest.gox \
        http/pprof.gox
 
 toolexeclibgoimagedir = $(toolexeclibgodir)/image
@@ -301,6 +305,11 @@ toolexeclibgoos_DATA = \
        $(os_inotify_gox) \
        os/signal.gox
 
+toolexeclibgopathdir = $(toolexeclibgodir)/path
+
+toolexeclibgopath_DATA = \
+       path/filepath.gox
+
 toolexeclibgorpcdir = $(toolexeclibgodir)/rpc
 
 toolexeclibgorpc_DATA = \
@@ -543,6 +552,7 @@ go_html_files = \
 go_http_files = \
        go/http/chunked.go \
        go/http/client.go \
+       go/http/cookie.go \
        go/http/dump.go \
        go/http/fs.go \
        go/http/header.go \
@@ -726,8 +736,7 @@ go_patch_files = \
 
 go_path_files = \
        go/path/match.go \
-       go/path/path.go \
-       go/path/path_unix.go
+       go/path/path.go
 
 go_rand_files = \
        go/rand/exp.go \
@@ -753,6 +762,7 @@ go_runtime_files = \
        go/runtime/debug.go \
        go/runtime/error.go \
        go/runtime/extern.go \
+       go/runtime/mem.go \
        go/runtime/sig.go \
        go/runtime/softfloat64.go \
        go/runtime/type.go \
@@ -826,6 +836,7 @@ go_testing_files = \
 go_time_files = \
        go/time/format.go \
        go/time/sleep.go \
+       go/time/sys.go \
        go/time/tick.go \
        go/time/time.go \
        go/time/zoneinfo_unix.go
@@ -936,6 +947,8 @@ go_crypto_cipher_files = \
        go/crypto/cipher/ofb.go
 go_crypto_dsa_files = \
        go/crypto/dsa/dsa.go
+go_crypto_ecdsa_files = \
+       go/crypto/ecdsa/ecdsa.go
 go_crypto_elliptic_files = \
        go/crypto/elliptic/elliptic.go
 go_crypto_hmac_files = \
@@ -1101,6 +1114,7 @@ go_go_token_files = \
        go/go/token/token.go
 go_go_typechecker_files = \
        go/go/typechecker/scope.go \
+       go/go/typechecker/type.go \
        go/go/typechecker/typechecker.go \
        go/go/typechecker/universe.go
 
@@ -1110,7 +1124,15 @@ go_hash_crc32_files = \
        go/hash/crc32/crc32.go
 go_hash_crc64_files = \
        go/hash/crc64/crc64.go
-
+go_hash_fnv_files = \
+       go/hash/fnv/fnv.go
+
+go_http_cgi_files = \
+       go/http/cgi/child.go \
+       go/http/cgi/host.go
+go_http_httptest_files = \
+       go/http/httptest/recorder.go \
+       go/http/httptest/server.go
 go_http_pprof_files = \
        go/http/pprof/pprof.go
 
@@ -1151,6 +1173,11 @@ go_os_signal_files = \
        go/os/signal/signal.go \
        unix.go
 
+go_path_filepath_files = \
+       go/path/filepath/match.go \
+       go/path/filepath/path.go \
+       go/path/filepath/path_unix.go
+
 go_rpc_jsonrpc_files = \
        go/rpc/jsonrpc/client.go \
        go/rpc/jsonrpc/server.go
@@ -1377,6 +1404,7 @@ libgo_go_objs = \
        crypto/cast5.lo \
        crypto/cipher.lo \
        crypto/dsa.lo \
+       crypto/ecdsa.lo \
        crypto/elliptic.lo \
        crypto/hmac.lo \
        crypto/md4.lo \
@@ -1426,6 +1454,9 @@ libgo_go_objs = \
        hash/adler32.lo \
        hash/crc32.lo \
        hash/crc64.lo \
+       hash/fnv.lo \
+       http/cgi.lo \
+       http/httptest.lo \
        http/pprof.lo \
        image/jpeg.lo \
        image/png.lo \
@@ -1436,6 +1467,7 @@ libgo_go_objs = \
        net/textproto.lo \
        $(os_lib_inotify_lo) \
        os/signal.lo \
+       path/filepath.lo \
        rpc/jsonrpc.lo \
        runtime/debug.lo \
        runtime/pprof.lo \
@@ -1532,7 +1564,7 @@ asn1/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: asn1/check
 
-big/big.lo: $(go_big_files) fmt.gox rand.gox strings.gox
+big/big.lo: $(go_big_files) fmt.gox rand.gox strings.gox os.gox
        $(BUILDPACKAGE)
 big/check: $(CHECK_DEPS)
        $(CHECK)
@@ -1597,9 +1629,9 @@ fmt/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: fmt/check
 
-gob/gob.lo: $(go_gob_files) bytes.gox fmt.gox io.gox math.gox os.gox \
-               reflect.gox runtime.gox strings.gox sync.gox unicode.gox \
-               utf8.gox
+gob/gob.lo: $(go_gob_files) bufio.gox bytes.gox fmt.gox io.gox math.gox \
+               os.gox reflect.gox runtime.gox strings.gox sync.gox \
+               unicode.gox utf8.gox
        $(BUILDPACKAGE)
 gob/check: $(CHECK_DEPS)
        $(CHECK)
@@ -1621,8 +1653,8 @@ html/check: $(CHECK_DEPS)
 http/http.lo: $(go_http_files) bufio.gox bytes.gox container/vector.gox \
                crypto/rand.gox crypto/tls.gox encoding/base64.gox fmt.gox \
                io.gox io/ioutil.gox log.gox mime.gox mime/multipart.gox \
-               net.gox net/textproto.gox os.gox path.gox sort.gox \
-               strconv.gox strings.gox sync.gox time.gox utf8.gox
+               net.gox net/textproto.gox os.gox path.gox path/filepath.gox \
+               sort.gox strconv.gox strings.gox sync.gox time.gox utf8.gox
        $(BUILDPACKAGE)
 http/check: $(CHECK_DEPS)
        $(CHECK)
@@ -1634,7 +1666,7 @@ image/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: image/check
 
-io/io.lo: $(go_io_files) os.gox runtime.gox sync.gox
+io/io.lo: $(go_io_files) os.gox sync.gox
        $(BUILDPACKAGE)
 io/check: $(CHECK_DEPS)
        $(CHECK)
@@ -1697,8 +1729,7 @@ patch/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: patch/check
 
-path/path.lo: $(go_path_files) io/ioutil.gox os.gox sort.gox strings.gox \
-               utf8.gox
+path/path.lo: $(go_path_files) os.gox strings.gox utf8.gox
        $(BUILDPACKAGE)
 path/check: $(CHECK_DEPS)
        $(CHECK)
@@ -1799,7 +1830,7 @@ template/check: $(CHECK_DEPS)
 .PHONY: template/check
 
 testing/testing.lo: $(go_testing_files) flag.gox fmt.gox os.gox regexp.gox \
-               runtime.gox time.gox
+               runtime.gox runtime/pprof.gox time.gox
        $(BUILDPACKAGE)
 testing/check: $(CHECK_DEPS)
        $(CHECK)
@@ -1862,7 +1893,7 @@ archive/tar/check: $(CHECK_DEPS)
 
 archive/zip.lo: $(go_archive_zip_files) bufio.gox bytes.gox \
                compress/flate.gox hash.gox hash/crc32.gox \
-               encoding/binary.gox io.gox os.gox
+               encoding/binary.gox io.gox io/ioutil.gox os.gox
        $(BUILDPACKAGE)
 archive/zip/check: $(CHECK_DEPS)
        @$(MKDIR_P) archive/zip
@@ -1977,6 +2008,14 @@ crypto/dsa/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: crypto/dsa/check
 
+crypto/ecdsa.lo: $(go_crypto_ecdsa_files) big.gox crypto/elliptic.gox io.gox \
+               os.gox
+       $(BUILDPACKAGE)
+crypto/ecdsa/check: $(CHECK_DEPS)
+       @$(MKDIR_P) crypto/ecdsa
+       $(CHECK)
+.PHONY: crypto/ecdsa/check
+
 crypto/elliptic.lo: $(go_crypto_elliptic_files) big.gox io.gox os.gox sync.gox
        $(BUILDPACKAGE)
 crypto/elliptic/check: $(CHECK_DEPS)
@@ -2014,8 +2053,8 @@ crypto/ocsp/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: crypto/ocsp/check
 
-crypto/openpgp.lo: $(go_crypto_openpgp_files) crypto.gox \
-                crypto/openpgp/armor.gox crypto/openpgp/error.gox \
+crypto/openpgp.lo: $(go_crypto_openpgp_files) crypto.gox crypto/dsa.gox \
+               crypto/openpgp/armor.gox crypto/openpgp/error.gox \
                crypto/openpgp/packet.gox crypto/rsa.gox crypto/sha256.gox \
                hash.gox io.gox os.gox strconv.gox time.gox
        $(BUILDPACKAGE)
@@ -2137,10 +2176,10 @@ crypto/openpgp/error/check: $(CHECK_DEPS)
 crypto/openpgp/packet.lo: $(go_crypto_openpgp_packet_files) big.gox bytes.gox \
                compress/flate.gox compress/zlib.gox crypto.gox \
                crypto/aes.gox crypto/cast5.gox crypto/cipher.gox \
-               crypto/openpgp/error.gox crypto/openpgp/s2k.gox \
-               crypto/rand.gox crypto/rsa.gox crypto/sha1.gox \
-               crypto/subtle.gox encoding/binary.gox hash.gox io.gox \
-               io/ioutil.gox os.gox strconv.gox strings.gox
+               crypto/dsa.gox crypto/openpgp/error.gox \
+               crypto/openpgp/s2k.gox crypto/rand.gox crypto/rsa.gox \
+               crypto/sha1.gox crypto/subtle.gox encoding/binary.gox fmt.gox \
+               hash.gox io.gox io/ioutil.gox os.gox strconv.gox strings.gox
        $(BUILDPACKAGE)
 crypto/openpgp/packet/check: $(CHECK_DEPS)
        @$(MKDIR_P) crypto/openpgp/packet
@@ -2288,8 +2327,8 @@ exp/eval/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: exp/eval/check
 
-go/ast.lo: $(go_go_ast_files) fmt.gox go/token.gox io.gox os.gox reflect.gox \
-               unicode.gox utf8.gox
+go/ast.lo: $(go_go_ast_files) bytes.gox fmt.gox go/token.gox io.gox os.gox \
+               reflect.gox unicode.gox utf8.gox
        $(BUILDPACKAGE)
 go/ast/check: $(CHECK_DEPS)
        @$(MKDIR_P) go/ast
@@ -2306,7 +2345,7 @@ go/doc/check: $(CHECK_DEPS)
 
 go/parser.lo: $(go_go_parser_files) bytes.gox fmt.gox go/ast.gox \
                go/scanner.gox go/token.gox io.gox io/ioutil.gox os.gox \
-               path.gox strings.gox
+               path/filepath.gox strings.gox
        $(BUILDPACKAGE)
 go/parser/check: $(CHECK_DEPS)
        @$(MKDIR_P) go/parser
@@ -2314,8 +2353,8 @@ go/parser/check: $(CHECK_DEPS)
 .PHONY: go/parser/check
 
 go/printer.lo: $(go_go_printer_files) bytes.gox fmt.gox go/ast.gox \
-               go/token.gox io.gox os.gox reflect.gox runtime.gox \
-               strings.gox tabwriter.gox
+               go/token.gox io.gox os.gox path/filepath.gox reflect.gox \
+               runtime.gox strings.gox tabwriter.gox
        $(BUILDPACKAGE)
 go/printer/check: $(CHECK_DEPS)
        @$(MKDIR_P) go/printer
@@ -2323,8 +2362,8 @@ go/printer/check: $(CHECK_DEPS)
 .PHONY: go/printer/check
 
 go/scanner.lo: $(go_go_scanner_files) bytes.gox container/vector.gox fmt.gox \
-               go/token.gox io.gox os.gox path.gox sort.gox strconv.gox \
-               unicode.gox utf8.gox
+               go/token.gox io.gox os.gox path/filepath.gox sort.gox \
+               strconv.gox unicode.gox utf8.gox
        $(BUILDPACKAGE)
 go/scanner/check: $(CHECK_DEPS)
        @$(MKDIR_P) go/scanner
@@ -2367,6 +2406,30 @@ hash/crc64/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: hash/crc64/check
 
+hash/fnv.lo: $(go_hash_fnv_files) encoding/binary.gox hash.gox os.gox
+       $(BUILDPACKAGE)
+hash/fnv/check: $(CHECK_DEPS)
+       @$(MKDIR_P) hash/fnv
+       $(CHECK)
+.PHONY: hash/fnv/check
+
+http/cgi.lo: $(go_http_cgi_files) bufio.gox bytes.gox encoding/line.gox \
+               exec.gox fmt.gox http.gox io.gox io/ioutil.gox log.gox \
+               os.gox path/filepath.gox regexp.gox strconv.gox strings.gox
+       $(BUILDPACKAGE)
+http/cgi/check: $(CHECK_DEPS)
+       @$(MKDIR_P) http/cgi
+       $(CHECK)
+.PHONY: http/cgi/check
+
+http/httptest.lo: $(go_http_httptest_files) bytes.gox fmt.gox http.gox \
+               net.gox os.gox
+       $(BUILDPACKAGE)
+http/httptest/check: $(CHECK_DEPS)
+       @$(MKDIR_P) http/httptest
+       $(CHECK)
+.PHONY: http/httptest/check
+
 http/pprof.lo: $(go_http_pprof_files) bufio.gox fmt.gox http.gox os.gox \
                runtime.gox runtime/pprof.gox strconv.gox strings.gox
        $(BUILDPACKAGE)
@@ -2398,8 +2461,8 @@ index/suffixarray/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: index/suffixarray/check
 
-io/ioutil.lo: $(go_io_ioutil_files) bytes.gox io.gox os.gox sort.gox \
-               strconv.gox
+io/ioutil.lo: $(go_io_ioutil_files) bytes.gox io.gox os.gox path/filepath.gox \
+               sort.gox strconv.gox
        $(BUILDPACKAGE)
 io/ioutil/check: $(CHECK_DEPS)
        @$(MKDIR_P) io/ioutil
@@ -2407,7 +2470,7 @@ io/ioutil/check: $(CHECK_DEPS)
 .PHONY: io/ioutil/check
 
 mime/multipart.lo: $(go_mime_multipart_files) bufio.gox bytes.gox io.gox \
-               mime.gox os.gox regexp.gox strings.gox
+               mime.gox net/textproto.gox os.gox regexp.gox strings.gox
        $(BUILDPACKAGE)
 mime/multipart/check: $(CHECK_DEPS)
        @$(MKDIR_P) mime/multipart
@@ -2445,6 +2508,14 @@ unix.go: $(srcdir)/go/os/signal/mkunix.sh sysinfo.go
        $(SHELL) $(srcdir)/go/os/signal/mkunix.sh sysinfo.go > $@.tmp
        mv -f $@.tmp $@
 
+path/filepath.lo: $(go_path_filepath_files) bytes.gox os.gox sort.gox \
+               strings.gox utf8.gox
+       $(BUILDPACKAGE)
+path/filepath/check: $(CHECK_DEPS)
+       @$(MKDIR_P) path/filepath
+       $(CHECK)
+.PHONY: path/filepath/check
+
 rpc/jsonrpc.lo: $(go_rpc_jsonrpc_files) fmt.gox io.gox json.gox net.gox \
                os.gox rpc.gox sync.gox
        $(BUILDPACKAGE)
@@ -2462,7 +2533,7 @@ runtime/debug/check: $(CHECK_DEPS)
 .PHONY: runtime/debug/check
 
 runtime/pprof.lo: $(go_runtime_pprof_files) bufio.gox fmt.gox io.gox os.gox \
-               runtime.gox
+               runtime.gox sync.gox
        $(BUILDPACKAGE)
 runtime/pprof/check: $(CHECK_DEPS)
        @$(MKDIR_P) runtime/pprof
@@ -2653,6 +2724,8 @@ crypto/cipher.gox: crypto/cipher.lo
        $(BUILDGOX)
 crypto/dsa.gox: crypto/dsa.lo
        $(BUILDGOX)
+crypto/ecdsa.gox: crypto/ecdsa.lo      
+       $(BUILDGOX)
 crypto/elliptic.gox: crypto/elliptic.lo
        $(BUILDGOX)
 crypto/hmac.gox: crypto/hmac.lo
@@ -2757,7 +2830,13 @@ hash/crc32.gox: hash/crc32.lo
        $(BUILDGOX)
 hash/crc64.gox: hash/crc64.lo
        $(BUILDGOX)
+hash/fnv.gox: hash/fnv.lo
+       $(BUILDGOX)
 
+http/cgi.gox: http/cgi.lo
+       $(BUILDGOX)
+http/httptest.gox: http/httptest.lo
+       $(BUILDGOX)
 http/pprof.gox: http/pprof.lo
        $(BUILDGOX)
 
@@ -2785,6 +2864,9 @@ os/inotify.gox: os/inotify.lo
 os/signal.gox: os/signal.lo
        $(BUILDGOX)
 
+path/filepath.gox: path/filepath.lo
+       $(BUILDGOX)
+
 rpc/jsonrpc.gox: rpc/jsonrpc.lo
        $(BUILDGOX)
 
@@ -2823,7 +2905,7 @@ TEST_PACKAGES = \
        fmt/check \
        gob/check \
        html/check \
-       $(if $(GCCGO_RUN_ALL_TESTS),http/check) \
+       http/check \
        io/check \
        json/check \
        log/check \
@@ -2872,6 +2954,7 @@ TEST_PACKAGES = \
        crypto/cast5/check \
        crypto/cipher/check \
        crypto/dsa/check \
+       crypto/ecdsa/check \
        crypto/elliptic/check \
        crypto/hmac/check \
        crypto/md4/check \
@@ -2916,6 +2999,8 @@ TEST_PACKAGES = \
        hash/adler32/check \
        hash/crc32/check \
        hash/crc64/check \
+       hash/fnv/check \
+       http/cgi/check \
        image/png/check \
        index/suffixarray/check \
        io/ioutil/check \
@@ -2923,6 +3008,7 @@ TEST_PACKAGES = \
        net/textproto/check \
        $(os_inotify_check) \
        os/signal/check \
+       path/filepath/check \
        rpc/jsonrpc/check \
        sync/atomic/check \
        testing/quick/check \
index 7bb302da9134d2fee613aba8aba7abf7d468a6ee..dd942254dcb7e53fff1d4d0de2755ed9c516ab4d 100644 (file)
@@ -110,6 +110,7 @@ am__installdirs = "$(DESTDIR)$(toolexeclibdir)" \
        "$(DESTDIR)$(toolexeclibgomimedir)" \
        "$(DESTDIR)$(toolexeclibgonetdir)" \
        "$(DESTDIR)$(toolexeclibgoosdir)" \
+       "$(DESTDIR)$(toolexeclibgopathdir)" \
        "$(DESTDIR)$(toolexeclibgorpcdir)" \
        "$(DESTDIR)$(toolexeclibgoruntimedir)" \
        "$(DESTDIR)$(toolexeclibgosyncdir)" \
@@ -141,9 +142,10 @@ am__DEPENDENCIES_2 = asn1/asn1.lo big/big.lo bufio/bufio.lo \
        container/heap.lo container/list.lo container/ring.lo \
        container/vector.lo crypto/aes.lo crypto/block.lo \
        crypto/blowfish.lo crypto/cast5.lo crypto/cipher.lo \
-       crypto/dsa.lo crypto/elliptic.lo crypto/hmac.lo crypto/md4.lo \
-       crypto/md5.lo crypto/ocsp.lo crypto/openpgp.lo crypto/rand.lo \
-       crypto/rc4.lo crypto/ripemd160.lo crypto/rsa.lo crypto/sha1.lo \
+       crypto/dsa.lo crypto/ecdsa.lo crypto/elliptic.lo \
+       crypto/hmac.lo crypto/md4.lo crypto/md5.lo crypto/ocsp.lo \
+       crypto/openpgp.lo crypto/rand.lo crypto/rc4.lo \
+       crypto/ripemd160.lo crypto/rsa.lo crypto/sha1.lo \
        crypto/sha256.lo crypto/sha512.lo crypto/subtle.lo \
        crypto/tls.lo crypto/twofish.lo crypto/x509.lo crypto/xtea.lo \
        crypto/openpgp/armor.lo crypto/openpgp/error.lo \
@@ -155,13 +157,14 @@ am__DEPENDENCIES_2 = asn1/asn1.lo big/big.lo bufio/bufio.lo \
        exp/datafmt.lo exp/draw.lo exp/eval.lo go/ast.lo go/doc.lo \
        go/parser.lo go/printer.lo go/scanner.lo go/token.lo \
        go/typechecker.lo hash/adler32.lo hash/crc32.lo hash/crc64.lo \
-       http/pprof.lo image/jpeg.lo image/png.lo index/suffixarray.lo \
-       io/ioutil.lo mime/multipart.lo net/dict.lo net/textproto.lo \
-       $(am__DEPENDENCIES_1) os/signal.lo rpc/jsonrpc.lo \
-       runtime/debug.lo runtime/pprof.lo sync/atomic.lo \
-       sync/atomic_c.lo syscalls/syscall.lo syscalls/errno.lo \
-       testing/testing.lo testing/iotest.lo testing/quick.lo \
-       testing/script.lo
+       hash/fnv.lo http/cgi.lo http/httptest.lo http/pprof.lo \
+       image/jpeg.lo image/png.lo index/suffixarray.lo io/ioutil.lo \
+       mime/multipart.lo net/dict.lo net/textproto.lo \
+       $(am__DEPENDENCIES_1) os/signal.lo path/filepath.lo \
+       rpc/jsonrpc.lo runtime/debug.lo runtime/pprof.lo \
+       sync/atomic.lo sync/atomic_c.lo syscalls/syscall.lo \
+       syscalls/errno.lo testing/testing.lo testing/iotest.lo \
+       testing/quick.lo testing/script.lo
 libgo_la_DEPENDENCIES = $(am__DEPENDENCIES_2) $(am__DEPENDENCIES_1) \
        $(am__DEPENDENCIES_1) $(am__DEPENDENCIES_1) \
        $(am__DEPENDENCIES_1)
@@ -280,8 +283,9 @@ DATA = $(toolexeclibgo_DATA) $(toolexeclibgoarchive_DATA) \
        $(toolexeclibgoimage_DATA) $(toolexeclibgoindex_DATA) \
        $(toolexeclibgoio_DATA) $(toolexeclibgomime_DATA) \
        $(toolexeclibgonet_DATA) $(toolexeclibgoos_DATA) \
-       $(toolexeclibgorpc_DATA) $(toolexeclibgoruntime_DATA) \
-       $(toolexeclibgosync_DATA) $(toolexeclibgotesting_DATA)
+       $(toolexeclibgopath_DATA) $(toolexeclibgorpc_DATA) \
+       $(toolexeclibgoruntime_DATA) $(toolexeclibgosync_DATA) \
+       $(toolexeclibgotesting_DATA)
 RECURSIVE_CLEAN_TARGETS = mostlyclean-recursive clean-recursive        \
   distclean-recursive maintainer-clean-recursive
 AM_RECURSIVE_TARGETS = $(RECURSIVE_TARGETS:-recursive=) \
@@ -620,6 +624,7 @@ toolexeclibgocrypto_DATA = \
        crypto/cast5.gox \
        crypto/cipher.gox \
        crypto/dsa.gox \
+       crypto/ecdsa.gox \
        crypto/elliptic.gox \
        crypto/hmac.gox \
        crypto/md4.gox \
@@ -686,10 +691,13 @@ toolexeclibgohashdir = $(toolexeclibgodir)/hash
 toolexeclibgohash_DATA = \
        hash/adler32.gox \
        hash/crc32.gox \
-       hash/crc64.gox
+       hash/crc64.gox \
+       hash/fnv.gox
 
 toolexeclibgohttpdir = $(toolexeclibgodir)/http
 toolexeclibgohttp_DATA = \
+       http/cgi.gox \
+       http/httptest.gox \
        http/pprof.gox
 
 toolexeclibgoimagedir = $(toolexeclibgodir)/image
@@ -723,6 +731,10 @@ toolexeclibgoos_DATA = \
        $(os_inotify_gox) \
        os/signal.gox
 
+toolexeclibgopathdir = $(toolexeclibgodir)/path
+toolexeclibgopath_DATA = \
+       path/filepath.gox
+
 toolexeclibgorpcdir = $(toolexeclibgodir)/rpc
 toolexeclibgorpc_DATA = \
        rpc/jsonrpc.gox
@@ -928,6 +940,7 @@ go_html_files = \
 go_http_files = \
        go/http/chunked.go \
        go/http/client.go \
+       go/http/cookie.go \
        go/http/dump.go \
        go/http/fs.go \
        go/http/header.go \
@@ -1084,8 +1097,7 @@ go_patch_files = \
 
 go_path_files = \
        go/path/match.go \
-       go/path/path.go \
-       go/path/path_unix.go
+       go/path/path.go
 
 go_rand_files = \
        go/rand/exp.go \
@@ -1111,6 +1123,7 @@ go_runtime_files = \
        go/runtime/debug.go \
        go/runtime/error.go \
        go/runtime/extern.go \
+       go/runtime/mem.go \
        go/runtime/sig.go \
        go/runtime/softfloat64.go \
        go/runtime/type.go \
@@ -1170,6 +1183,7 @@ go_testing_files = \
 go_time_files = \
        go/time/format.go \
        go/time/sleep.go \
+       go/time/sys.go \
        go/time/tick.go \
        go/time/time.go \
        go/time/zoneinfo_unix.go
@@ -1286,6 +1300,9 @@ go_crypto_cipher_files = \
 go_crypto_dsa_files = \
        go/crypto/dsa/dsa.go
 
+go_crypto_ecdsa_files = \
+       go/crypto/ecdsa/ecdsa.go
+
 go_crypto_elliptic_files = \
        go/crypto/elliptic/elliptic.go
 
@@ -1490,6 +1507,7 @@ go_go_token_files = \
 
 go_go_typechecker_files = \
        go/go/typechecker/scope.go \
+       go/go/typechecker/type.go \
        go/go/typechecker/typechecker.go \
        go/go/typechecker/universe.go
 
@@ -1502,6 +1520,17 @@ go_hash_crc32_files = \
 go_hash_crc64_files = \
        go/hash/crc64/crc64.go
 
+go_hash_fnv_files = \
+       go/hash/fnv/fnv.go
+
+go_http_cgi_files = \
+       go/http/cgi/child.go \
+       go/http/cgi/host.go
+
+go_http_httptest_files = \
+       go/http/httptest/recorder.go \
+       go/http/httptest/server.go
+
 go_http_pprof_files = \
        go/http/pprof/pprof.go
 
@@ -1542,6 +1571,11 @@ go_os_signal_files = \
        go/os/signal/signal.go \
        unix.go
 
+go_path_filepath_files = \
+       go/path/filepath/match.go \
+       go/path/filepath/path.go \
+       go/path/filepath/path_unix.go
+
 go_rpc_jsonrpc_files = \
        go/rpc/jsonrpc/client.go \
        go/rpc/jsonrpc/server.go
@@ -1718,6 +1752,7 @@ libgo_go_objs = \
        crypto/cast5.lo \
        crypto/cipher.lo \
        crypto/dsa.lo \
+       crypto/ecdsa.lo \
        crypto/elliptic.lo \
        crypto/hmac.lo \
        crypto/md4.lo \
@@ -1767,6 +1802,9 @@ libgo_go_objs = \
        hash/adler32.lo \
        hash/crc32.lo \
        hash/crc64.lo \
+       hash/fnv.lo \
+       http/cgi.lo \
+       http/httptest.lo \
        http/pprof.lo \
        image/jpeg.lo \
        image/png.lo \
@@ -1777,6 +1815,7 @@ libgo_go_objs = \
        net/textproto.lo \
        $(os_lib_inotify_lo) \
        os/signal.lo \
+       path/filepath.lo \
        rpc/jsonrpc.lo \
        runtime/debug.lo \
        runtime/pprof.lo \
@@ -1883,7 +1922,7 @@ TEST_PACKAGES = \
        fmt/check \
        gob/check \
        html/check \
-       $(if $(GCCGO_RUN_ALL_TESTS),http/check) \
+       http/check \
        io/check \
        json/check \
        log/check \
@@ -1932,6 +1971,7 @@ TEST_PACKAGES = \
        crypto/cast5/check \
        crypto/cipher/check \
        crypto/dsa/check \
+       crypto/ecdsa/check \
        crypto/elliptic/check \
        crypto/hmac/check \
        crypto/md4/check \
@@ -1976,6 +2016,8 @@ TEST_PACKAGES = \
        hash/adler32/check \
        hash/crc32/check \
        hash/crc64/check \
+       hash/fnv/check \
+       http/cgi/check \
        image/png/check \
        index/suffixarray/check \
        io/ioutil/check \
@@ -1983,6 +2025,7 @@ TEST_PACKAGES = \
        net/textproto/check \
        $(os_inotify_check) \
        os/signal/check \
+       path/filepath/check \
        rpc/jsonrpc/check \
        sync/atomic/check \
        testing/quick/check \
@@ -3271,6 +3314,26 @@ uninstall-toolexeclibgoosDATA:
        test -n "$$files" || exit 0; \
        echo " ( cd '$(DESTDIR)$(toolexeclibgoosdir)' && rm -f" $$files ")"; \
        cd "$(DESTDIR)$(toolexeclibgoosdir)" && rm -f $$files
+install-toolexeclibgopathDATA: $(toolexeclibgopath_DATA)
+       @$(NORMAL_INSTALL)
+       test -z "$(toolexeclibgopathdir)" || $(MKDIR_P) "$(DESTDIR)$(toolexeclibgopathdir)"
+       @list='$(toolexeclibgopath_DATA)'; test -n "$(toolexeclibgopathdir)" || list=; \
+       for p in $$list; do \
+         if test -f "$$p"; then d=; else d="$(srcdir)/"; fi; \
+         echo "$$d$$p"; \
+       done | $(am__base_list) | \
+       while read files; do \
+         echo " $(INSTALL_DATA) $$files '$(DESTDIR)$(toolexeclibgopathdir)'"; \
+         $(INSTALL_DATA) $$files "$(DESTDIR)$(toolexeclibgopathdir)" || exit $$?; \
+       done
+
+uninstall-toolexeclibgopathDATA:
+       @$(NORMAL_UNINSTALL)
+       @list='$(toolexeclibgopath_DATA)'; test -n "$(toolexeclibgopathdir)" || list=; \
+       files=`for p in $$list; do echo $$p; done | sed -e 's|^.*/||'`; \
+       test -n "$$files" || exit 0; \
+       echo " ( cd '$(DESTDIR)$(toolexeclibgopathdir)' && rm -f" $$files ")"; \
+       cd "$(DESTDIR)$(toolexeclibgopathdir)" && rm -f $$files
 install-toolexeclibgorpcDATA: $(toolexeclibgorpc_DATA)
        @$(NORMAL_INSTALL)
        test -z "$(toolexeclibgorpcdir)" || $(MKDIR_P) "$(DESTDIR)$(toolexeclibgorpcdir)"
@@ -3668,7 +3731,7 @@ all-am: Makefile $(LIBRARIES) $(LTLIBRARIES) all-multi $(DATA) \
                config.h
 installdirs: installdirs-recursive
 installdirs-am:
-       for dir in "$(DESTDIR)$(toolexeclibdir)" "$(DESTDIR)$(toolexeclibdir)" "$(DESTDIR)$(toolexeclibgodir)" "$(DESTDIR)$(toolexeclibgoarchivedir)" "$(DESTDIR)$(toolexeclibgocompressdir)" "$(DESTDIR)$(toolexeclibgocontainerdir)" "$(DESTDIR)$(toolexeclibgocryptodir)" "$(DESTDIR)$(toolexeclibgocryptoopenpgpdir)" "$(DESTDIR)$(toolexeclibgodebugdir)" "$(DESTDIR)$(toolexeclibgoencodingdir)" "$(DESTDIR)$(toolexeclibgoexpdir)" "$(DESTDIR)$(toolexeclibgogodir)" "$(DESTDIR)$(toolexeclibgohashdir)" "$(DESTDIR)$(toolexeclibgohttpdir)" "$(DESTDIR)$(toolexeclibgoimagedir)" "$(DESTDIR)$(toolexeclibgoindexdir)" "$(DESTDIR)$(toolexeclibgoiodir)" "$(DESTDIR)$(toolexeclibgomimedir)" "$(DESTDIR)$(toolexeclibgonetdir)" "$(DESTDIR)$(toolexeclibgoosdir)" "$(DESTDIR)$(toolexeclibgorpcdir)" "$(DESTDIR)$(toolexeclibgoruntimedir)" "$(DESTDIR)$(toolexeclibgosyncdir)" "$(DESTDIR)$(toolexeclibgotestingdir)"; do \
+       for dir in "$(DESTDIR)$(toolexeclibdir)" "$(DESTDIR)$(toolexeclibdir)" "$(DESTDIR)$(toolexeclibgodir)" "$(DESTDIR)$(toolexeclibgoarchivedir)" "$(DESTDIR)$(toolexeclibgocompressdir)" "$(DESTDIR)$(toolexeclibgocontainerdir)" "$(DESTDIR)$(toolexeclibgocryptodir)" "$(DESTDIR)$(toolexeclibgocryptoopenpgpdir)" "$(DESTDIR)$(toolexeclibgodebugdir)" "$(DESTDIR)$(toolexeclibgoencodingdir)" "$(DESTDIR)$(toolexeclibgoexpdir)" "$(DESTDIR)$(toolexeclibgogodir)" "$(DESTDIR)$(toolexeclibgohashdir)" "$(DESTDIR)$(toolexeclibgohttpdir)" "$(DESTDIR)$(toolexeclibgoimagedir)" "$(DESTDIR)$(toolexeclibgoindexdir)" "$(DESTDIR)$(toolexeclibgoiodir)" "$(DESTDIR)$(toolexeclibgomimedir)" "$(DESTDIR)$(toolexeclibgonetdir)" "$(DESTDIR)$(toolexeclibgoosdir)" "$(DESTDIR)$(toolexeclibgopathdir)" "$(DESTDIR)$(toolexeclibgorpcdir)" "$(DESTDIR)$(toolexeclibgoruntimedir)" "$(DESTDIR)$(toolexeclibgosyncdir)" "$(DESTDIR)$(toolexeclibgotestingdir)"; do \
          test -z "$$dir" || $(MKDIR_P) "$$dir"; \
        done
 install: install-recursive
@@ -3741,9 +3804,9 @@ install-exec-am: install-multi install-toolexeclibLIBRARIES \
        install-toolexeclibgohttpDATA install-toolexeclibgoimageDATA \
        install-toolexeclibgoindexDATA install-toolexeclibgoioDATA \
        install-toolexeclibgomimeDATA install-toolexeclibgonetDATA \
-       install-toolexeclibgoosDATA install-toolexeclibgorpcDATA \
-       install-toolexeclibgoruntimeDATA install-toolexeclibgosyncDATA \
-       install-toolexeclibgotestingDATA
+       install-toolexeclibgoosDATA install-toolexeclibgopathDATA \
+       install-toolexeclibgorpcDATA install-toolexeclibgoruntimeDATA \
+       install-toolexeclibgosyncDATA install-toolexeclibgotestingDATA
 
 install-html: install-html-recursive
 
@@ -3800,7 +3863,8 @@ uninstall-am: uninstall-toolexeclibLIBRARIES \
        uninstall-toolexeclibgoimageDATA \
        uninstall-toolexeclibgoindexDATA uninstall-toolexeclibgoioDATA \
        uninstall-toolexeclibgomimeDATA uninstall-toolexeclibgonetDATA \
-       uninstall-toolexeclibgoosDATA uninstall-toolexeclibgorpcDATA \
+       uninstall-toolexeclibgoosDATA uninstall-toolexeclibgopathDATA \
+       uninstall-toolexeclibgorpcDATA \
        uninstall-toolexeclibgoruntimeDATA \
        uninstall-toolexeclibgosyncDATA \
        uninstall-toolexeclibgotestingDATA
@@ -3836,15 +3900,15 @@ uninstall-am: uninstall-toolexeclibLIBRARIES \
        install-toolexeclibgohttpDATA install-toolexeclibgoimageDATA \
        install-toolexeclibgoindexDATA install-toolexeclibgoioDATA \
        install-toolexeclibgomimeDATA install-toolexeclibgonetDATA \
-       install-toolexeclibgoosDATA install-toolexeclibgorpcDATA \
-       install-toolexeclibgoruntimeDATA install-toolexeclibgosyncDATA \
-       install-toolexeclibgotestingDATA installcheck installcheck-am \
-       installdirs installdirs-am maintainer-clean \
-       maintainer-clean-generic maintainer-clean-multi mostlyclean \
-       mostlyclean-compile mostlyclean-generic mostlyclean-libtool \
-       mostlyclean-local mostlyclean-multi pdf pdf-am ps ps-am tags \
-       tags-recursive uninstall uninstall-am \
-       uninstall-toolexeclibLIBRARIES \
+       install-toolexeclibgoosDATA install-toolexeclibgopathDATA \
+       install-toolexeclibgorpcDATA install-toolexeclibgoruntimeDATA \
+       install-toolexeclibgosyncDATA install-toolexeclibgotestingDATA \
+       installcheck installcheck-am installdirs installdirs-am \
+       maintainer-clean maintainer-clean-generic \
+       maintainer-clean-multi mostlyclean mostlyclean-compile \
+       mostlyclean-generic mostlyclean-libtool mostlyclean-local \
+       mostlyclean-multi pdf pdf-am ps ps-am tags tags-recursive \
+       uninstall uninstall-am uninstall-toolexeclibLIBRARIES \
        uninstall-toolexeclibLTLIBRARIES uninstall-toolexeclibgoDATA \
        uninstall-toolexeclibgoarchiveDATA \
        uninstall-toolexeclibgocompressDATA \
@@ -3859,7 +3923,8 @@ uninstall-am: uninstall-toolexeclibLIBRARIES \
        uninstall-toolexeclibgoimageDATA \
        uninstall-toolexeclibgoindexDATA uninstall-toolexeclibgoioDATA \
        uninstall-toolexeclibgomimeDATA uninstall-toolexeclibgonetDATA \
-       uninstall-toolexeclibgoosDATA uninstall-toolexeclibgorpcDATA \
+       uninstall-toolexeclibgoosDATA uninstall-toolexeclibgopathDATA \
+       uninstall-toolexeclibgorpcDATA \
        uninstall-toolexeclibgoruntimeDATA \
        uninstall-toolexeclibgosyncDATA \
        uninstall-toolexeclibgotestingDATA
@@ -3918,7 +3983,7 @@ asn1/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: asn1/check
 
-big/big.lo: $(go_big_files) fmt.gox rand.gox strings.gox
+big/big.lo: $(go_big_files) fmt.gox rand.gox strings.gox os.gox
        $(BUILDPACKAGE)
 big/check: $(CHECK_DEPS)
        $(CHECK)
@@ -3983,9 +4048,9 @@ fmt/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: fmt/check
 
-gob/gob.lo: $(go_gob_files) bytes.gox fmt.gox io.gox math.gox os.gox \
-               reflect.gox runtime.gox strings.gox sync.gox unicode.gox \
-               utf8.gox
+gob/gob.lo: $(go_gob_files) bufio.gox bytes.gox fmt.gox io.gox math.gox \
+               os.gox reflect.gox runtime.gox strings.gox sync.gox \
+               unicode.gox utf8.gox
        $(BUILDPACKAGE)
 gob/check: $(CHECK_DEPS)
        $(CHECK)
@@ -4007,8 +4072,8 @@ html/check: $(CHECK_DEPS)
 http/http.lo: $(go_http_files) bufio.gox bytes.gox container/vector.gox \
                crypto/rand.gox crypto/tls.gox encoding/base64.gox fmt.gox \
                io.gox io/ioutil.gox log.gox mime.gox mime/multipart.gox \
-               net.gox net/textproto.gox os.gox path.gox sort.gox \
-               strconv.gox strings.gox sync.gox time.gox utf8.gox
+               net.gox net/textproto.gox os.gox path.gox path/filepath.gox \
+               sort.gox strconv.gox strings.gox sync.gox time.gox utf8.gox
        $(BUILDPACKAGE)
 http/check: $(CHECK_DEPS)
        $(CHECK)
@@ -4020,7 +4085,7 @@ image/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: image/check
 
-io/io.lo: $(go_io_files) os.gox runtime.gox sync.gox
+io/io.lo: $(go_io_files) os.gox sync.gox
        $(BUILDPACKAGE)
 io/check: $(CHECK_DEPS)
        $(CHECK)
@@ -4083,8 +4148,7 @@ patch/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: patch/check
 
-path/path.lo: $(go_path_files) io/ioutil.gox os.gox sort.gox strings.gox \
-               utf8.gox
+path/path.lo: $(go_path_files) os.gox strings.gox utf8.gox
        $(BUILDPACKAGE)
 path/check: $(CHECK_DEPS)
        $(CHECK)
@@ -4185,7 +4249,7 @@ template/check: $(CHECK_DEPS)
 .PHONY: template/check
 
 testing/testing.lo: $(go_testing_files) flag.gox fmt.gox os.gox regexp.gox \
-               runtime.gox time.gox
+               runtime.gox runtime/pprof.gox time.gox
        $(BUILDPACKAGE)
 testing/check: $(CHECK_DEPS)
        $(CHECK)
@@ -4248,7 +4312,7 @@ archive/tar/check: $(CHECK_DEPS)
 
 archive/zip.lo: $(go_archive_zip_files) bufio.gox bytes.gox \
                compress/flate.gox hash.gox hash/crc32.gox \
-               encoding/binary.gox io.gox os.gox
+               encoding/binary.gox io.gox io/ioutil.gox os.gox
        $(BUILDPACKAGE)
 archive/zip/check: $(CHECK_DEPS)
        @$(MKDIR_P) archive/zip
@@ -4363,6 +4427,14 @@ crypto/dsa/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: crypto/dsa/check
 
+crypto/ecdsa.lo: $(go_crypto_ecdsa_files) big.gox crypto/elliptic.gox io.gox \
+               os.gox
+       $(BUILDPACKAGE)
+crypto/ecdsa/check: $(CHECK_DEPS)
+       @$(MKDIR_P) crypto/ecdsa
+       $(CHECK)
+.PHONY: crypto/ecdsa/check
+
 crypto/elliptic.lo: $(go_crypto_elliptic_files) big.gox io.gox os.gox sync.gox
        $(BUILDPACKAGE)
 crypto/elliptic/check: $(CHECK_DEPS)
@@ -4400,8 +4472,8 @@ crypto/ocsp/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: crypto/ocsp/check
 
-crypto/openpgp.lo: $(go_crypto_openpgp_files) crypto.gox \
-                crypto/openpgp/armor.gox crypto/openpgp/error.gox \
+crypto/openpgp.lo: $(go_crypto_openpgp_files) crypto.gox crypto/dsa.gox \
+               crypto/openpgp/armor.gox crypto/openpgp/error.gox \
                crypto/openpgp/packet.gox crypto/rsa.gox crypto/sha256.gox \
                hash.gox io.gox os.gox strconv.gox time.gox
        $(BUILDPACKAGE)
@@ -4523,10 +4595,10 @@ crypto/openpgp/error/check: $(CHECK_DEPS)
 crypto/openpgp/packet.lo: $(go_crypto_openpgp_packet_files) big.gox bytes.gox \
                compress/flate.gox compress/zlib.gox crypto.gox \
                crypto/aes.gox crypto/cast5.gox crypto/cipher.gox \
-               crypto/openpgp/error.gox crypto/openpgp/s2k.gox \
-               crypto/rand.gox crypto/rsa.gox crypto/sha1.gox \
-               crypto/subtle.gox encoding/binary.gox hash.gox io.gox \
-               io/ioutil.gox os.gox strconv.gox strings.gox
+               crypto/dsa.gox crypto/openpgp/error.gox \
+               crypto/openpgp/s2k.gox crypto/rand.gox crypto/rsa.gox \
+               crypto/sha1.gox crypto/subtle.gox encoding/binary.gox fmt.gox \
+               hash.gox io.gox io/ioutil.gox os.gox strconv.gox strings.gox
        $(BUILDPACKAGE)
 crypto/openpgp/packet/check: $(CHECK_DEPS)
        @$(MKDIR_P) crypto/openpgp/packet
@@ -4674,8 +4746,8 @@ exp/eval/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: exp/eval/check
 
-go/ast.lo: $(go_go_ast_files) fmt.gox go/token.gox io.gox os.gox reflect.gox \
-               unicode.gox utf8.gox
+go/ast.lo: $(go_go_ast_files) bytes.gox fmt.gox go/token.gox io.gox os.gox \
+               reflect.gox unicode.gox utf8.gox
        $(BUILDPACKAGE)
 go/ast/check: $(CHECK_DEPS)
        @$(MKDIR_P) go/ast
@@ -4692,7 +4764,7 @@ go/doc/check: $(CHECK_DEPS)
 
 go/parser.lo: $(go_go_parser_files) bytes.gox fmt.gox go/ast.gox \
                go/scanner.gox go/token.gox io.gox io/ioutil.gox os.gox \
-               path.gox strings.gox
+               path/filepath.gox strings.gox
        $(BUILDPACKAGE)
 go/parser/check: $(CHECK_DEPS)
        @$(MKDIR_P) go/parser
@@ -4700,8 +4772,8 @@ go/parser/check: $(CHECK_DEPS)
 .PHONY: go/parser/check
 
 go/printer.lo: $(go_go_printer_files) bytes.gox fmt.gox go/ast.gox \
-               go/token.gox io.gox os.gox reflect.gox runtime.gox \
-               strings.gox tabwriter.gox
+               go/token.gox io.gox os.gox path/filepath.gox reflect.gox \
+               runtime.gox strings.gox tabwriter.gox
        $(BUILDPACKAGE)
 go/printer/check: $(CHECK_DEPS)
        @$(MKDIR_P) go/printer
@@ -4709,8 +4781,8 @@ go/printer/check: $(CHECK_DEPS)
 .PHONY: go/printer/check
 
 go/scanner.lo: $(go_go_scanner_files) bytes.gox container/vector.gox fmt.gox \
-               go/token.gox io.gox os.gox path.gox sort.gox strconv.gox \
-               unicode.gox utf8.gox
+               go/token.gox io.gox os.gox path/filepath.gox sort.gox \
+               strconv.gox unicode.gox utf8.gox
        $(BUILDPACKAGE)
 go/scanner/check: $(CHECK_DEPS)
        @$(MKDIR_P) go/scanner
@@ -4753,6 +4825,30 @@ hash/crc64/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: hash/crc64/check
 
+hash/fnv.lo: $(go_hash_fnv_files) encoding/binary.gox hash.gox os.gox
+       $(BUILDPACKAGE)
+hash/fnv/check: $(CHECK_DEPS)
+       @$(MKDIR_P) hash/fnv
+       $(CHECK)
+.PHONY: hash/fnv/check
+
+http/cgi.lo: $(go_http_cgi_files) bufio.gox bytes.gox encoding/line.gox \
+               exec.gox fmt.gox http.gox io.gox io/ioutil.gox log.gox \
+               os.gox path/filepath.gox regexp.gox strconv.gox strings.gox
+       $(BUILDPACKAGE)
+http/cgi/check: $(CHECK_DEPS)
+       @$(MKDIR_P) http/cgi
+       $(CHECK)
+.PHONY: http/cgi/check
+
+http/httptest.lo: $(go_http_httptest_files) bytes.gox fmt.gox http.gox \
+               net.gox os.gox
+       $(BUILDPACKAGE)
+http/httptest/check: $(CHECK_DEPS)
+       @$(MKDIR_P) http/httptest
+       $(CHECK)
+.PHONY: http/httptest/check
+
 http/pprof.lo: $(go_http_pprof_files) bufio.gox fmt.gox http.gox os.gox \
                runtime.gox runtime/pprof.gox strconv.gox strings.gox
        $(BUILDPACKAGE)
@@ -4784,8 +4880,8 @@ index/suffixarray/check: $(CHECK_DEPS)
        $(CHECK)
 .PHONY: index/suffixarray/check
 
-io/ioutil.lo: $(go_io_ioutil_files) bytes.gox io.gox os.gox sort.gox \
-               strconv.gox
+io/ioutil.lo: $(go_io_ioutil_files) bytes.gox io.gox os.gox path/filepath.gox \
+               sort.gox strconv.gox
        $(BUILDPACKAGE)
 io/ioutil/check: $(CHECK_DEPS)
        @$(MKDIR_P) io/ioutil
@@ -4793,7 +4889,7 @@ io/ioutil/check: $(CHECK_DEPS)
 .PHONY: io/ioutil/check
 
 mime/multipart.lo: $(go_mime_multipart_files) bufio.gox bytes.gox io.gox \
-               mime.gox os.gox regexp.gox strings.gox
+               mime.gox net/textproto.gox os.gox regexp.gox strings.gox
        $(BUILDPACKAGE)
 mime/multipart/check: $(CHECK_DEPS)
        @$(MKDIR_P) mime/multipart
@@ -4831,6 +4927,14 @@ unix.go: $(srcdir)/go/os/signal/mkunix.sh sysinfo.go
        $(SHELL) $(srcdir)/go/os/signal/mkunix.sh sysinfo.go > $@.tmp
        mv -f $@.tmp $@
 
+path/filepath.lo: $(go_path_filepath_files) bytes.gox os.gox sort.gox \
+               strings.gox utf8.gox
+       $(BUILDPACKAGE)
+path/filepath/check: $(CHECK_DEPS)
+       @$(MKDIR_P) path/filepath
+       $(CHECK)
+.PHONY: path/filepath/check
+
 rpc/jsonrpc.lo: $(go_rpc_jsonrpc_files) fmt.gox io.gox json.gox net.gox \
                os.gox rpc.gox sync.gox
        $(BUILDPACKAGE)
@@ -4848,7 +4952,7 @@ runtime/debug/check: $(CHECK_DEPS)
 .PHONY: runtime/debug/check
 
 runtime/pprof.lo: $(go_runtime_pprof_files) bufio.gox fmt.gox io.gox os.gox \
-               runtime.gox
+               runtime.gox sync.gox
        $(BUILDPACKAGE)
 runtime/pprof/check: $(CHECK_DEPS)
        @$(MKDIR_P) runtime/pprof
@@ -5034,6 +5138,8 @@ crypto/cipher.gox: crypto/cipher.lo
        $(BUILDGOX)
 crypto/dsa.gox: crypto/dsa.lo
        $(BUILDGOX)
+crypto/ecdsa.gox: crypto/ecdsa.lo      
+       $(BUILDGOX)
 crypto/elliptic.gox: crypto/elliptic.lo
        $(BUILDGOX)
 crypto/hmac.gox: crypto/hmac.lo
@@ -5138,7 +5244,13 @@ hash/crc32.gox: hash/crc32.lo
        $(BUILDGOX)
 hash/crc64.gox: hash/crc64.lo
        $(BUILDGOX)
+hash/fnv.gox: hash/fnv.lo
+       $(BUILDGOX)
 
+http/cgi.gox: http/cgi.lo
+       $(BUILDGOX)
+http/httptest.gox: http/httptest.lo
+       $(BUILDGOX)
 http/pprof.gox: http/pprof.lo
        $(BUILDGOX)
 
@@ -5166,6 +5278,9 @@ os/inotify.gox: os/inotify.lo
 os/signal.gox: os/signal.lo
        $(BUILDGOX)
 
+path/filepath.gox: path/filepath.lo
+       $(BUILDGOX)
+
 rpc/jsonrpc.gox: rpc/jsonrpc.lo
        $(BUILDGOX)
 
index d8d9bba60bc56e977c43213a236af5732360d0e4..3b265c9b72e6fe99cf50fb88ee13fabd44400a77 100644 (file)
@@ -19,6 +19,7 @@ import (
        "hash/crc32"
        "encoding/binary"
        "io"
+       "io/ioutil"
        "os"
 )
 
@@ -109,7 +110,7 @@ func (f *File) Open() (rc io.ReadCloser, err os.Error) {
        r := io.NewSectionReader(f.zipr, off+f.bodyOffset, size)
        switch f.Method {
        case 0: // store (no compression)
-               rc = nopCloser{r}
+               rc = ioutil.NopCloser(r)
        case 8: // DEFLATE
                rc = flate.NewReader(r)
        default:
@@ -147,12 +148,6 @@ func (r *checksumReader) Read(b []byte) (n int, err os.Error) {
 
 func (r *checksumReader) Close() os.Error { return r.rc.Close() }
 
-type nopCloser struct {
-       io.Reader
-}
-
-func (f nopCloser) Close() os.Error { return nil }
-
 func readFileHeader(f *File, r io.Reader) (err os.Error) {
        defer func() {
                if rerr, ok := recover().(os.Error); ok {
index 46e0087343a9a0bce422cd0a637b8e4d52efecd8..ecd70e03ef10b954c6a011d5c6408839d24e9f7b 100644 (file)
@@ -8,6 +8,7 @@ package big
 
 import (
        "fmt"
+       "os"
        "rand"
 )
 
@@ -393,62 +394,19 @@ func (z *Int) SetString(s string, base int) (*Int, bool) {
 }
 
 
-// SetBytes interprets b as the bytes of a big-endian, unsigned integer and
-// sets z to that value.
-func (z *Int) SetBytes(b []byte) *Int {
-       const s = _S
-       z.abs = z.abs.make((len(b) + s - 1) / s)
-
-       j := 0
-       for len(b) >= s {
-               var w Word
-
-               for i := s; i > 0; i-- {
-                       w <<= 8
-                       w |= Word(b[len(b)-i])
-               }
-
-               z.abs[j] = w
-               j++
-               b = b[0 : len(b)-s]
-       }
-
-       if len(b) > 0 {
-               var w Word
-
-               for i := len(b); i > 0; i-- {
-                       w <<= 8
-                       w |= Word(b[len(b)-i])
-               }
-
-               z.abs[j] = w
-       }
-
-       z.abs = z.abs.norm()
+// SetBytes interprets buf as the bytes of a big-endian unsigned
+// integer, sets z to that value, and returns z.
+func (z *Int) SetBytes(buf []byte) *Int {
+       z.abs = z.abs.setBytes(buf)
        z.neg = false
        return z
 }
 
 
-// Bytes returns the absolute value of x as a big-endian byte array.
+// Bytes returns the absolute value of z as a big-endian byte slice.
 func (z *Int) Bytes() []byte {
-       const s = _S
-       b := make([]byte, len(z.abs)*s)
-
-       for i, w := range z.abs {
-               wordBytes := b[(len(z.abs)-i-1)*s : (len(z.abs)-i)*s]
-               for j := s - 1; j >= 0; j-- {
-                       wordBytes[j] = byte(w)
-                       w >>= 8
-               }
-       }
-
-       i := 0
-       for i < len(b) && b[i] == 0 {
-               i++
-       }
-
-       return b[i:]
+       buf := make([]byte, len(z.abs)*_S)
+       return buf[z.abs.bytes(buf):]
 }
 
 
@@ -739,3 +697,34 @@ func (z *Int) Not(x *Int) *Int {
        z.neg = true // z cannot be zero if x is positive
        return z
 }
+
+
+// Gob codec version. Permits backward-compatible changes to the encoding.
+const version byte = 1
+
+// GobEncode implements the gob.GobEncoder interface.
+func (z *Int) GobEncode() ([]byte, os.Error) {
+       buf := make([]byte, len(z.abs)*_S+1) // extra byte for version and sign bit
+       i := z.abs.bytes(buf) - 1            // i >= 0
+       b := version << 1                    // make space for sign bit
+       if z.neg {
+               b |= 1
+       }
+       buf[i] = b
+       return buf[i:], nil
+}
+
+
+// GobDecode implements the gob.GobDecoder interface.
+func (z *Int) GobDecode(buf []byte) os.Error {
+       if len(buf) == 0 {
+               return os.NewError("Int.GobDecode: no data")
+       }
+       b := buf[0]
+       if b>>1 != version {
+               return os.NewError(fmt.Sprintf("Int.GobDecode: encoding version %d not supported", b>>1))
+       }
+       z.neg = b&1 != 0
+       z.abs = z.abs.setBytes(buf[1:])
+       return nil
+}
index fc981e1da46c4fab2ff04e0d61a3e3082f8539a4..c0cc9accf1a86103b6596d40dd35eda95b41aaad 100644 (file)
@@ -8,6 +8,7 @@ import (
        "bytes"
        "encoding/hex"
        "fmt"
+       "gob"
        "testing"
        "testing/quick"
 )
@@ -1053,3 +1054,41 @@ func TestModInverse(t *testing.T) {
                }
        }
 }
+
+
+var gobEncodingTests = []string{
+       "0",
+       "1",
+       "2",
+       "10",
+       "42",
+       "1234567890",
+       "298472983472983471903246121093472394872319615612417471234712061",
+}
+
+func TestGobEncoding(t *testing.T) {
+       var medium bytes.Buffer
+       enc := gob.NewEncoder(&medium)
+       dec := gob.NewDecoder(&medium)
+       for i, test := range gobEncodingTests {
+               for j := 0; j < 2; j++ {
+                       medium.Reset() // empty buffer for each test case (in case of failures)
+                       stest := test
+                       if j == 0 {
+                               stest = "-" + test
+                       }
+                       var tx Int
+                       tx.SetString(stest, 10)
+                       if err := enc.Encode(&tx); err != nil {
+                               t.Errorf("#%d%c: encoding failed: %s", i, 'a'+j, err)
+                       }
+                       var rx Int
+                       if err := dec.Decode(&rx); err != nil {
+                               t.Errorf("#%d%c: decoding failed: %s", i, 'a'+j, err)
+                       }
+                       if rx.Cmp(&tx) != 0 {
+                               t.Errorf("#%d%c: transmission failed: got %s want %s", i, 'a'+j, &rx, &tx)
+                       }
+               }
+       }
+}
index a308f69e8cd7ea8605128f5d69f0f42bfad6f61f..a04d3b1d9c1e4cc21552d72228fac0246a98fe76 100644 (file)
@@ -1065,3 +1065,50 @@ NextRandom:
 
        return true
 }
+
+
+// bytes writes the value of z into buf using big-endian encoding.
+// len(buf) must be >= len(z)*_S. The value of z is encoded in the
+// slice buf[i:]. The number i of unused bytes at the beginning of
+// buf is returned as result.
+func (z nat) bytes(buf []byte) (i int) {
+       i = len(buf)
+       for _, d := range z {
+               for j := 0; j < _S; j++ {
+                       i--
+                       buf[i] = byte(d)
+                       d >>= 8
+               }
+       }
+
+       for i < len(buf) && buf[i] == 0 {
+               i++
+       }
+
+       return
+}
+
+
+// setBytes interprets buf as the bytes of a big-endian unsigned
+// integer, sets z to that value, and returns z.
+func (z nat) setBytes(buf []byte) nat {
+       z = z.make((len(buf) + _S - 1) / _S)
+
+       k := 0
+       s := uint(0)
+       var d Word
+       for i := len(buf); i > 0; i-- {
+               d |= Word(buf[i-1]) << s
+               if s += 8; s == _S*8 {
+                       z[k] = d
+                       k++
+                       s = 0
+                       d = 0
+               }
+       }
+       if k < len(z) {
+               z[k] = d
+       }
+
+       return z.norm()
+}
index 059ca6dd223d91ea2a6d42bd989ccb9e1db580e3..8028e04dcd972dc6b51828afab7e36dcc9138a5c 100644 (file)
@@ -2,9 +2,10 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-package bufio
+package bufio_test
 
 import (
+       . "bufio"
        "bytes"
        "fmt"
        "io"
@@ -502,9 +503,8 @@ func TestWriteString(t *testing.T) {
        b.WriteString("7890")                      // easy after flush
        b.WriteString("abcdefghijklmnopqrstuvwxy") // hard
        b.WriteString("z")
-       b.Flush()
-       if b.err != nil {
-               t.Error("WriteString", b.err)
+       if err := b.Flush(); err != nil {
+               t.Error("WriteString", err)
        }
        s := "01234567890abcdefghijklmnopqrstuvwxyz"
        if string(buf.Bytes()) != s {
index ff54164b2cce81244ca6887626e067ca795ae25f..ed5884a4b78de8a1437bedd82a69b59c97da4b20 100644 (file)
@@ -191,9 +191,16 @@ func testSync(t *testing.T, level int, input []byte, name string) {
                        t.Errorf("testSync/%d: read wrong bytes: %x vs %x", i, input[lo:hi], out[:hi-lo])
                        return
                }
-               if i == 0 && buf.buf.Len() != 0 {
-                       t.Errorf("testSync/%d (%d, %d, %s): extra data after %d", i, level, len(input), name, hi-lo)
-               }
+               // This test originally checked that after reading
+               // the first half of the input, there was nothing left
+               // in the read buffer (buf.buf.Len() != 0) but that is
+               // not necessarily the case: the write Flush may emit
+               // some extra framing bits that are not necessary
+               // to process to obtain the first half of the uncompressed
+               // data.  The test ran correctly most of the time, because
+               // the background goroutine had usually read even
+               // those extra bits by now, but it's not a useful thing to
+               // check.
                buf.WriteMode()
        }
        buf.ReadMode()
index 7795a4c14896486f07129ebea00addf3a9c0a27a..4b5dfaadea2023c51568f4109615decb22a2c683 100644 (file)
@@ -9,6 +9,7 @@ import (
        "io"
        "io/ioutil"
        "os"
+       "runtime"
        "strconv"
        "strings"
        "testing"
@@ -117,16 +118,34 @@ func (devNull) Write(p []byte) (int, os.Error) {
        return len(p), nil
 }
 
-func BenchmarkDecoder(b *testing.B) {
+func benchmarkDecoder(b *testing.B, n int) {
        b.StopTimer()
+       b.SetBytes(int64(n))
        buf0, _ := ioutil.ReadFile("../testdata/e.txt")
+       buf0 = buf0[:10000]
        compressed := bytes.NewBuffer(nil)
        w := NewWriter(compressed, LSB, 8)
-       io.Copy(w, bytes.NewBuffer(buf0))
+       for i := 0; i < n; i += len(buf0) {
+               io.Copy(w, bytes.NewBuffer(buf0))
+       }
        w.Close()
        buf1 := compressed.Bytes()
+       buf0, compressed, w = nil, nil, nil
+       runtime.GC()
        b.StartTimer()
        for i := 0; i < b.N; i++ {
                io.Copy(devNull{}, NewReader(bytes.NewBuffer(buf1), LSB, 8))
        }
 }
+
+func BenchmarkDecoder1e4(b *testing.B) {
+       benchmarkDecoder(b, 1e4)
+}
+
+func BenchmarkDecoder1e5(b *testing.B) {
+       benchmarkDecoder(b, 1e5)
+}
+
+func BenchmarkDecoder1e6(b *testing.B) {
+       benchmarkDecoder(b, 1e6)
+}
index 715b974aa1ec20ab11fcb65ef1e95a51e938d2ab..2e0a8de0a87f84124d3c754b8b7831177b7de479 100644 (file)
@@ -8,6 +8,7 @@ import (
        "io"
        "io/ioutil"
        "os"
+       "runtime"
        "testing"
 )
 
@@ -99,13 +100,33 @@ func TestWriter(t *testing.T) {
        }
 }
 
-func BenchmarkEncoder(b *testing.B) {
+func benchmarkEncoder(b *testing.B, n int) {
        b.StopTimer()
-       buf, _ := ioutil.ReadFile("../testdata/e.txt")
+       b.SetBytes(int64(n))
+       buf0, _ := ioutil.ReadFile("../testdata/e.txt")
+       buf0 = buf0[:10000]
+       buf1 := make([]byte, n)
+       for i := 0; i < n; i += len(buf0) {
+               copy(buf1[i:], buf0)
+       }
+       buf0 = nil
+       runtime.GC()
        b.StartTimer()
        for i := 0; i < b.N; i++ {
                w := NewWriter(devNull{}, LSB, 8)
-               w.Write(buf)
+               w.Write(buf1)
                w.Close()
        }
 }
+
+func BenchmarkEncoder1e4(b *testing.B) {
+       benchmarkEncoder(b, 1e4)
+}
+
+func BenchmarkEncoder1e5(b *testing.B) {
+       benchmarkEncoder(b, 1e5)
+}
+
+func BenchmarkEncoder1e6(b *testing.B) {
+       benchmarkEncoder(b, 1e6)
+}
diff --git a/libgo/go/crypto/ecdsa/ecdsa.go b/libgo/go/crypto/ecdsa/ecdsa.go
new file mode 100644 (file)
index 0000000..1f37849
--- /dev/null
@@ -0,0 +1,149 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package ecdsa implements the Elliptic Curve Digital Signature Algorithm, as
+// defined in FIPS 186-3.
+package ecdsa
+
+// References:
+//   [NSA]: Suite B implementor's guide to FIPS 186-3,
+//     http://www.nsa.gov/ia/_files/ecdsa.pdf
+//   [SECG]: SECG, SEC1
+//     http://www.secg.org/download/aid-780/sec1-v2.pdf
+
+import (
+       "big"
+       "crypto/elliptic"
+       "io"
+       "os"
+)
+
+// PublicKey represents an ECDSA public key.
+type PublicKey struct {
+       *elliptic.Curve
+       X, Y *big.Int
+}
+
+// PrivateKey represents a ECDSA private key.
+type PrivateKey struct {
+       PublicKey
+       D *big.Int
+}
+
+var one = new(big.Int).SetInt64(1)
+
+// randFieldElement returns a random element of the field underlying the given
+// curve using the procedure given in [NSA] A.2.1.
+func randFieldElement(c *elliptic.Curve, rand io.Reader) (k *big.Int, err os.Error) {
+       b := make([]byte, c.BitSize/8+8)
+       _, err = rand.Read(b)
+       if err != nil {
+               return
+       }
+
+       k = new(big.Int).SetBytes(b)
+       n := new(big.Int).Sub(c.N, one)
+       k.Mod(k, n)
+       k.Add(k, one)
+       return
+}
+
+// GenerateKey generates a public&private key pair.
+func GenerateKey(c *elliptic.Curve, rand io.Reader) (priv *PrivateKey, err os.Error) {
+       k, err := randFieldElement(c, rand)
+       if err != nil {
+               return
+       }
+
+       priv = new(PrivateKey)
+       priv.PublicKey.Curve = c
+       priv.D = k
+       priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
+       return
+}
+
+// hashToInt converts a hash value to an integer. There is some disagreement
+// about how this is done. [NSA] suggests that this is done in the obvious
+// manner, but [SECG] truncates the hash to the bit-length of the curve order
+// first. We follow [SECG] because that's what OpenSSL does.
+func hashToInt(hash []byte, c *elliptic.Curve) *big.Int {
+       orderBits := c.N.BitLen()
+       orderBytes := (orderBits + 7) / 8
+       if len(hash) > orderBytes {
+               hash = hash[:orderBytes]
+       }
+
+       ret := new(big.Int).SetBytes(hash)
+       excess := orderBytes*8 - orderBits
+       if excess > 0 {
+               ret.Rsh(ret, uint(excess))
+       }
+       return ret
+}
+
+// Sign signs an arbitrary length hash (which should be the result of hashing a
+// larger message) using the private key, priv. It returns the signature as a
+// pair of integers. The security of the private key depends on the entropy of
+// rand.
+func Sign(rand io.Reader, priv *PrivateKey, hash []byte) (r, s *big.Int, err os.Error) {
+       // See [NSA] 3.4.1
+       c := priv.PublicKey.Curve
+
+       var k, kInv *big.Int
+       for {
+               for {
+                       k, err = randFieldElement(c, rand)
+                       if err != nil {
+                               r = nil
+                               return
+                       }
+
+                       kInv = new(big.Int).ModInverse(k, c.N)
+                       r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
+                       r.Mod(r, priv.Curve.N)
+                       if r.Sign() != 0 {
+                               break
+                       }
+               }
+
+               e := hashToInt(hash, c)
+               s = new(big.Int).Mul(priv.D, r)
+               s.Add(s, e)
+               s.Mul(s, kInv)
+               s.Mod(s, priv.PublicKey.Curve.N)
+               if s.Sign() != 0 {
+                       break
+               }
+       }
+
+       return
+}
+
+// Verify verifies the signature in r, s of hash using the public key, pub. It
+// returns true iff the signature is valid.
+func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
+       // See [NSA] 3.4.2
+       c := pub.Curve
+
+       if r.Sign() == 0 || s.Sign() == 0 {
+               return false
+       }
+       if r.Cmp(c.N) >= 0 || s.Cmp(c.N) >= 0 {
+               return false
+       }
+       e := hashToInt(hash, c)
+       w := new(big.Int).ModInverse(s, c.N)
+
+       u1 := e.Mul(e, w)
+       u2 := w.Mul(r, w)
+
+       x1, y1 := c.ScalarBaseMult(u1.Bytes())
+       x2, y2 := c.ScalarMult(pub.X, pub.Y, u2.Bytes())
+       if x1.Cmp(x2) == 0 {
+               return false
+       }
+       x, _ := c.Add(x1, y1, x2, y2)
+       x.Mod(x, c.N)
+       return x.Cmp(r) == 0
+}
diff --git a/libgo/go/crypto/ecdsa/ecdsa_test.go b/libgo/go/crypto/ecdsa/ecdsa_test.go
new file mode 100644 (file)
index 0000000..cc22b7a
--- /dev/null
@@ -0,0 +1,218 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ecdsa
+
+import (
+       "big"
+       "crypto/elliptic"
+       "crypto/sha1"
+       "crypto/rand"
+       "encoding/hex"
+       "testing"
+)
+
+func testKeyGeneration(t *testing.T, c *elliptic.Curve, tag string) {
+       priv, err := GenerateKey(c, rand.Reader)
+       if err != nil {
+               t.Errorf("%s: error: %s", tag, err)
+               return
+       }
+       if !c.IsOnCurve(priv.PublicKey.X, priv.PublicKey.Y) {
+               t.Errorf("%s: public key invalid", tag, err)
+       }
+}
+
+func TestKeyGeneration(t *testing.T) {
+       testKeyGeneration(t, elliptic.P224(), "p224")
+       testKeyGeneration(t, elliptic.P256(), "p256")
+       testKeyGeneration(t, elliptic.P384(), "p384")
+       testKeyGeneration(t, elliptic.P521(), "p521")
+}
+
+func testSignAndVerify(t *testing.T, c *elliptic.Curve, tag string) {
+       priv, _ := GenerateKey(c, rand.Reader)
+
+       hashed := []byte("testing")
+       r, s, err := Sign(rand.Reader, priv, hashed)
+       if err != nil {
+               t.Errorf("%s: error signing: %s", tag, err)
+               return
+       }
+
+       if !Verify(&priv.PublicKey, hashed, r, s) {
+               t.Errorf("%s: Verify failed", tag)
+       }
+
+       hashed[0] ^= 0xff
+       if Verify(&priv.PublicKey, hashed, r, s) {
+               t.Errorf("%s: Verify always works!", tag)
+       }
+}
+
+func TestSignAndVerify(t *testing.T) {
+       testSignAndVerify(t, elliptic.P224(), "p224")
+       testSignAndVerify(t, elliptic.P256(), "p256")
+       testSignAndVerify(t, elliptic.P384(), "p384")
+       testSignAndVerify(t, elliptic.P521(), "p521")
+}
+
+func fromHex(s string) *big.Int {
+       r, ok := new(big.Int).SetString(s, 16)
+       if !ok {
+               panic("bad hex")
+       }
+       return r
+}
+
+// These test vectors were taken from
+//   http://csrc.nist.gov/groups/STM/cavp/documents/dss/ecdsatestvectors.zip
+var testVectors = []struct {
+       msg    string
+       Qx, Qy string
+       r, s   string
+       ok     bool
+}{
+       {
+               "09626b45493672e48f3d1226a3aff3201960e577d33a7f72c7eb055302db8fe8ed61685dd036b554942a5737cd1512cdf811ee0c00e6dd2f08c69f08643be396e85dafda664801e772cdb7396868ac47b172245b41986aa2648cb77fbbfa562581be06651355a0c4b090f9d17d8f0ab6cced4e0c9d386cf465a516630f0231bd",
+               "9504b5b82d97a264d8b3735e0568decabc4b6ca275bc53cbadfc1c40",
+               "03426f80e477603b10dee670939623e3da91a94267fc4e51726009ed",
+               "81d3ac609f9575d742028dd496450a58a60eea2dcf8b9842994916e1",
+               "96a8c5f382c992e8f30ccce9af120b067ec1d74678fa8445232f75a5",
+               false,
+       },
+       {
+               "96b2b6536f6df29be8567a72528aceeaccbaa66c66c534f3868ca9778b02faadb182e4ed34662e73b9d52ecbe9dc8e875fc05033c493108b380689ebf47e5b062e6a0cdb3dd34ce5fe347d92768d72f7b9b377c20aea927043b509c078ed2467d7113405d2ddd458811e6faf41c403a2a239240180f1430a6f4330df5d77de37",
+               "851e3100368a22478a0029353045ae40d1d8202ef4d6533cfdddafd8",
+               "205302ac69457dd345e86465afa72ee8c74ca97e2b0b999aec1f10c2",
+               "4450c2d38b697e990721aa2dbb56578d32b4f5aeb3b9072baa955ee0",
+               "e26d4b589166f7b4ba4b1c8fce823fa47aad22f8c9c396b8c6526e12",
+               false,
+       },
+       {
+               "86778dbb4a068a01047a8d245d632f636c11d2ad350740b36fad90428b454ad0f120cb558d12ea5c8a23db595d87543d06d1ef489263d01ee529871eb68737efdb8ff85bc7787b61514bed85b7e01d6be209e0a4eb0db5c8df58a5c5bf706d76cb2bdf7800208639e05b89517155d11688236e6a47ed37d8e5a2b1e0adea338e",
+               "ad5bda09d319a717c1721acd6688d17020b31b47eef1edea57ceeffc",
+               "c8ce98e181770a7c9418c73c63d01494b8b80a41098c5ea50692c984",
+               "de5558c257ab4134e52c19d8db3b224a1899cbd08cc508ce8721d5e9",
+               "745db7af5a477e5046705c0a5eff1f52cb94a79d481f0c5a5e108ecd",
+               true,
+       },
+       {
+               "4bc6ef1958556686dab1e39c3700054a304cbd8f5928603dcd97fafd1f29e69394679b638f71c9344ce6a535d104803d22119f57b5f9477e253817a52afa9bfbc9811d6cc8c8be6b6566c6ef48b439bbb532abe30627548c598867f3861ba0b154dc1c3deca06eb28df8efd28258554b5179883a36fbb1eecf4f93ee19d41e3d",
+               "cc5eea2edf964018bdc0504a3793e4d2145142caa09a72ac5fb8d3e8",
+               "a48d78ae5d08aa725342773975a00d4219cf7a8029bb8cf3c17c374a",
+               "67b861344b4e416d4094472faf4272f6d54a497177fbc5f9ef292836",
+               "1d54f3fcdad795bf3b23408ecbac3e1321d1d66f2e4e3d05f41f7020",
+               false,
+       },
+       {
+               "bb658732acbf3147729959eb7318a2058308b2739ec58907dd5b11cfa3ecf69a1752b7b7d806fe00ec402d18f96039f0b78dbb90a59c4414fb33f1f4e02e4089de4122cd93df5263a95be4d7084e2126493892816e6a5b4ed123cb705bf930c8f67af0fb4514d5769232a9b008a803af225160ce63f675bd4872c4c97b146e5e",
+               "6234c936e27bf141fc7534bfc0a7eedc657f91308203f1dcbd642855",
+               "27983d87ca785ef4892c3591ef4a944b1deb125dd58bd351034a6f84",
+               "e94e05b42d01d0b965ffdd6c3a97a36a771e8ea71003de76c4ecb13f",
+               "1dc6464ffeefbd7872a081a5926e9fc3e66d123f1784340ba17737e9",
+               false,
+       },
+       {
+               "7c00be9123bfa2c4290be1d8bc2942c7f897d9a5b7917e3aabd97ef1aab890f148400a89abd554d19bec9d8ed911ce57b22fbcf6d30ca2115f13ce0a3f569a23bad39ee645f624c49c60dcfc11e7d2be24de9c905596d8f23624d63dc46591d1f740e46f982bfae453f107e80db23545782be23ce43708245896fc54e1ee5c43",
+               "9f3f037282aaf14d4772edffff331bbdda845c3f65780498cde334f1",
+               "8308ee5a16e3bcb721b6bc30000a0419bc1aaedd761be7f658334066",
+               "6381d7804a8808e3c17901e4d283b89449096a8fba993388fa11dc54",
+               "8e858f6b5b253686a86b757bad23658cda53115ac565abca4e3d9f57",
+               false,
+       },
+       {
+               "cffc122a44840dc705bb37130069921be313d8bde0b66201aebc48add028ca131914ef2e705d6bedd19dc6cf9459bbb0f27cdfe3c50483808ffcdaffbeaa5f062e097180f07a40ef4ab6ed03fe07ed6bcfb8afeb42c97eafa2e8a8df469de07317c5e1494c41547478eff4d8c7d9f0f484ad90fedf6e1c35ee68fa73f1691601",
+               "a03b88a10d930002c7b17ca6af2fd3e88fa000edf787dc594f8d4fd4",
+               "e0cf7acd6ddc758e64847fe4df9915ebda2f67cdd5ec979aa57421f5",
+               "387b84dcf37dc343c7d2c5beb82f0bf8bd894b395a7b894565d296c1",
+               "4adc12ce7d20a89ce3925e10491c731b15ddb3f339610857a21b53b4",
+               false,
+       },
+       {
+               "26e0e0cafd85b43d16255908ccfd1f061c680df75aba3081246b337495783052ba06c60f4a486c1591a4048bae11b4d7fec4f161d80bdc9a7b79d23e44433ed625eab280521a37f23dd3e1bdc5c6a6cfaa026f3c45cf703e76dab57add93fe844dd4cda67dc3bddd01f9152579e49df60969b10f09ce9372fdd806b0c7301866",
+               "9a8983c42f2b5a87c37a00458b5970320d247f0c8a88536440173f7d",
+               "15e489ec6355351361900299088cfe8359f04fe0cab78dde952be80c",
+               "929a21baa173d438ec9f28d6a585a2f9abcfc0a4300898668e476dc0",
+               "59a853f046da8318de77ff43f26fe95a92ee296fa3f7e56ce086c872",
+               true,
+       },
+       {
+               "1078eac124f48ae4f807e946971d0de3db3748dd349b14cca5c942560fb25401b2252744f18ad5e455d2d97ed5ae745f55ff509c6c8e64606afe17809affa855c4c4cdcaf6b69ab4846aa5624ed0687541aee6f2224d929685736c6a23906d974d3c257abce1a3fb8db5951b89ecb0cda92b5207d93f6618fd0f893c32cf6a6e",
+               "d6e55820bb62c2be97650302d59d667a411956138306bd566e5c3c2b",
+               "631ab0d64eaf28a71b9cbd27a7a88682a2167cee6251c44e3810894f",
+               "65af72bc7721eb71c2298a0eb4eed3cec96a737cc49125706308b129",
+               "bd5a987c78e2d51598dbd9c34a9035b0069c580edefdacee17ad892a",
+               false,
+       },
+       {
+               "919deb1fdd831c23481dfdb2475dcbe325b04c34f82561ced3d2df0b3d749b36e255c4928973769d46de8b95f162b53cd666cad9ae145e7fcfba97919f703d864efc11eac5f260a5d920d780c52899e5d76f8fe66936ff82130761231f536e6a3d59792f784902c469aa897aabf9a0678f93446610d56d5e0981e4c8a563556b",
+               "269b455b1024eb92d860a420f143ac1286b8cce43031562ae7664574",
+               "baeb6ca274a77c44a0247e5eb12ca72bdd9a698b3f3ae69c9f1aaa57",
+               "cb4ec2160f04613eb0dfe4608486091a25eb12aa4dec1afe91cfb008",
+               "40b01d8cd06589481574f958b98ca08ade9d2a8fe31024375c01bb40",
+               false,
+       },
+       {
+               "6e012361250dacf6166d2dd1aa7be544c3206a9d43464b3fcd90f3f8cf48d08ec099b59ba6fe7d9bdcfaf244120aed1695d8be32d1b1cd6f143982ab945d635fb48a7c76831c0460851a3d62b7209c30cd9c2abdbe3d2a5282a9fcde1a6f418dd23c409bc351896b9b34d7d3a1a63bbaf3d677e612d4a80fa14829386a64b33f",
+               "6d2d695efc6b43b13c14111f2109608f1020e3e03b5e21cfdbc82fcd",
+               "26a4859296b7e360b69cf40be7bd97ceaffa3d07743c8489fc47ca1b",
+               "9a8cb5f2fdc288b7183c5b32d8e546fc2ed1ca4285eeae00c8b572ad",
+               "8c623f357b5d0057b10cdb1a1593dab57cda7bdec9cf868157a79b97",
+               true,
+       },
+       {
+               "bf6bd7356a52b234fe24d25557200971fc803836f6fec3cade9642b13a8e7af10ab48b749de76aada9d8927f9b12f75a2c383ca7358e2566c4bb4f156fce1fd4e87ef8c8d2b6b1bdd351460feb22cdca0437ac10ca5e0abbbce9834483af20e4835386f8b1c96daaa41554ceee56730aac04f23a5c765812efa746051f396566",
+               "14250131b2599939cf2d6bc491be80ddfe7ad9de644387ee67de2d40",
+               "b5dc473b5d014cd504022043c475d3f93c319a8bdcb7262d9e741803",
+               "4f21642f2201278a95339a80f75cc91f8321fcb3c9462562f6cbf145",
+               "452a5f816ea1f75dee4fd514fa91a0d6a43622981966c59a1b371ff8",
+               false,
+       },
+       {
+               "0eb7f4032f90f0bd3cf9473d6d9525d264d14c031a10acd31a053443ed5fe919d5ac35e0be77813071b4062f0b5fdf58ad5f637b76b0b305aec18f82441b6e607b44cdf6e0e3c7c57f24e6fd565e39430af4a6b1d979821ed0175fa03e3125506847654d7e1ae904ce1190ae38dc5919e257bdac2db142a6e7cd4da6c2e83770",
+               "d1f342b7790a1667370a1840255ac5bbbdc66f0bc00ae977d99260ac",
+               "76416cabae2de9a1000b4646338b774baabfa3db4673790771220cdb",
+               "bc85e3fc143d19a7271b2f9e1c04b86146073f3fab4dda1c3b1f35ca",
+               "9a5c70ede3c48d5f43307a0c2a4871934424a3303b815df4bb0f128e",
+               false,
+       },
+       {
+               "5cc25348a05d85e56d4b03cec450128727bc537c66ec3a9fb613c151033b5e86878632249cba83adcefc6c1e35dcd31702929c3b57871cda5c18d1cf8f9650a25b917efaed56032e43b6fc398509f0d2997306d8f26675f3a8683b79ce17128e006aa0903b39eeb2f1001be65de0520115e6f919de902b32c38d691a69c58c92",
+               "7e49a7abf16a792e4c7bbc4d251820a2abd22d9f2fc252a7bf59c9a6",
+               "44236a8fb4791c228c26637c28ae59503a2f450d4cfb0dc42aa843b9",
+               "084461b4050285a1a85b2113be76a17878d849e6bc489f4d84f15cd8",
+               "079b5bddcc4d45de8dbdfd39f69817c7e5afa454a894d03ee1eaaac3",
+               false,
+       },
+       {
+               "1951533ce33afb58935e39e363d8497a8dd0442018fd96dff167b3b23d7206a3ee182a3194765df4768a3284e23b8696c199b4686e670d60c9d782f08794a4bccc05cffffbd1a12acd9eb1cfa01f7ebe124da66ecff4599ea7720c3be4bb7285daa1a86ebf53b042bd23208d468c1b3aa87381f8e1ad63e2b4c2ba5efcf05845",
+               "31945d12ebaf4d81f02be2b1768ed80784bf35cf5e2ff53438c11493",
+               "a62bebffac987e3b9d3ec451eb64c462cdf7b4aa0b1bbb131ceaa0a4",
+               "bc3c32b19e42b710bca5c6aaa128564da3ddb2726b25f33603d2af3c",
+               "ed1a719cc0c507edc5239d76fe50e2306c145ad252bd481da04180c0",
+               false,
+       },
+}
+
+func TestVectors(t *testing.T) {
+       sha := sha1.New()
+
+       for i, test := range testVectors {
+               pub := PublicKey{
+                       Curve: elliptic.P224(),
+                       X:     fromHex(test.Qx),
+                       Y:     fromHex(test.Qy),
+               }
+               msg, _ := hex.DecodeString(test.msg)
+               sha.Reset()
+               sha.Write(msg)
+               hashed := sha.Sum()
+               r := fromHex(test.r)
+               s := fromHex(test.s)
+               if Verify(&pub, hashed, r, s) != test.ok {
+                       t.Errorf("%d: bad result", i)
+               }
+       }
+}
index beac45ca074489e641423491eac1f4fb8a2abbb8..2296e9607776affef2b91d98583d24f9a72c83f1 100644 (file)
@@ -24,6 +24,7 @@ import (
 // See http://www.hyperelliptic.org/EFD/g1p/auto-shortw.html
 type Curve struct {
        P       *big.Int // the order of the underlying field
+       N       *big.Int // the order of the base point
        B       *big.Int // the constant of the curve equation
        Gx, Gy  *big.Int // (x,y) of the base point
        BitSize int      // the size of the underlying field
@@ -315,6 +316,7 @@ func initP224() {
        // See FIPS 186-3, section D.2.2
        p224 = new(Curve)
        p224.P, _ = new(big.Int).SetString("26959946667150639794667015087019630673557916260026308143510066298881", 10)
+       p224.N, _ = new(big.Int).SetString("26959946667150639794667015087019625940457807714424391721682722368061", 10)
        p224.B, _ = new(big.Int).SetString("b4050a850c04b3abf54132565044b0b7d7bfd8ba270b39432355ffb4", 16)
        p224.Gx, _ = new(big.Int).SetString("b70e0cbd6bb4bf7f321390b94a03c1d356c21122343280d6115c1d21", 16)
        p224.Gy, _ = new(big.Int).SetString("bd376388b5f723fb4c22dfe6cd4375a05a07476444d5819985007e34", 16)
@@ -325,6 +327,7 @@ func initP256() {
        // See FIPS 186-3, section D.2.3
        p256 = new(Curve)
        p256.P, _ = new(big.Int).SetString("115792089210356248762697446949407573530086143415290314195533631308867097853951", 10)
+       p256.N, _ = new(big.Int).SetString("115792089210356248762697446949407573529996955224135760342422259061068512044369", 10)
        p256.B, _ = new(big.Int).SetString("5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b", 16)
        p256.Gx, _ = new(big.Int).SetString("6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296", 16)
        p256.Gy, _ = new(big.Int).SetString("4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5", 16)
@@ -335,6 +338,7 @@ func initP384() {
        // See FIPS 186-3, section D.2.4
        p384 = new(Curve)
        p384.P, _ = new(big.Int).SetString("39402006196394479212279040100143613805079739270465446667948293404245721771496870329047266088258938001861606973112319", 10)
+       p384.N, _ = new(big.Int).SetString("39402006196394479212279040100143613805079739270465446667946905279627659399113263569398956308152294913554433653942643", 10)
        p384.B, _ = new(big.Int).SetString("b3312fa7e23ee7e4988e056be3f82d19181d9c6efe8141120314088f5013875ac656398d8a2ed19d2a85c8edd3ec2aef", 16)
        p384.Gx, _ = new(big.Int).SetString("aa87ca22be8b05378eb1c71ef320ad746e1d3b628ba79b9859f741e082542a385502f25dbf55296c3a545e3872760ab7", 16)
        p384.Gy, _ = new(big.Int).SetString("3617de4a96262c6f5d9e98bf9292dc29f8f41dbd289a147ce9da3113b5f0b8c00a60b1ce1d7e819d7a431d7c90ea0e5f", 16)
@@ -345,6 +349,7 @@ func initP521() {
        // See FIPS 186-3, section D.2.5
        p521 = new(Curve)
        p521.P, _ = new(big.Int).SetString("6864797660130609714981900799081393217269435300143305409394463459185543183397656052122559640661454554977296311391480858037121987999716643812574028291115057151", 10)
+       p521.N, _ = new(big.Int).SetString("6864797660130609714981900799081393217269435300143305409394463459185543183397655394245057746333217197532963996371363321113864768612440380340372808892707005449", 10)
        p521.B, _ = new(big.Int).SetString("051953eb9618e1c9a1f929a21a0b68540eea2da725b99b315f3b8b489918ef109e156193951ec7e937b1652c0bd3bb1bf073573df883d2c34f1ef451fd46b503f00", 16)
        p521.Gx, _ = new(big.Int).SetString("c6858e06b70404e9cd9e3ecb662395b4429c648139053fb521f828af606b4d3dbaa14b5e77efe75928fe1dc127a2ffa8de3348b3c1856a429bf97e7e31c2e5bd66", 16)
        p521.Gy, _ = new(big.Int).SetString("11839296a789a3bc0045c8a5fb42c7d1bd998f54449579b446817afbd17273e662c97ee72995ef42640c550b9013fad0761353c7086a272c24088be94769fd16650", 16)
index 269603ba4985557a0885db6c43a85697cf22a318..57ff3afbfc702609c48950f55a51d1b20b7b9861 100644 (file)
@@ -7,6 +7,7 @@
 package packet
 
 import (
+       "big"
        "crypto/aes"
        "crypto/cast5"
        "crypto/cipher"
@@ -166,10 +167,10 @@ func readHeader(r io.Reader) (tag packetType, length int64, contents io.Reader,
        return
 }
 
-// serialiseHeader writes an OpenPGP packet header to w. See RFC 4880, section
+// serializeHeader writes an OpenPGP packet header to w. See RFC 4880, section
 // 4.2.
-func serialiseHeader(w io.Writer, ptype packetType, length int) (err os.Error) {
-       var buf [5]byte
+func serializeHeader(w io.Writer, ptype packetType, length int) (err os.Error) {
+       var buf [6]byte
        var n int
 
        buf[0] = 0x80 | 0x40 | byte(ptype)
@@ -178,16 +179,16 @@ func serialiseHeader(w io.Writer, ptype packetType, length int) (err os.Error) {
                n = 2
        } else if length < 8384 {
                length -= 192
-               buf[1] = byte(length >> 8)
+               buf[1] = 192 + byte(length>>8)
                buf[2] = byte(length)
                n = 3
        } else {
-               buf[0] = 255
-               buf[1] = byte(length >> 24)
-               buf[2] = byte(length >> 16)
-               buf[3] = byte(length >> 8)
-               buf[4] = byte(length)
-               n = 5
+               buf[1] = 255
+               buf[2] = byte(length >> 24)
+               buf[3] = byte(length >> 16)
+               buf[4] = byte(length >> 8)
+               buf[5] = byte(length)
+               n = 6
        }
 
        _, err = w.Write(buf[:n])
@@ -371,7 +372,7 @@ func (cipher CipherFunction) new(key []byte) (block cipher.Block) {
 
 // readMPI reads a big integer from r. The bit length returned is the bit
 // length that was specified in r. This is preserved so that the integer can be
-// reserialised exactly.
+// reserialized exactly.
 func readMPI(r io.Reader) (mpi []byte, bitLength uint16, err os.Error) {
        var buf [2]byte
        _, err = readFull(r, buf[0:])
@@ -385,7 +386,7 @@ func readMPI(r io.Reader) (mpi []byte, bitLength uint16, err os.Error) {
        return
 }
 
-// writeMPI serialises a big integer to r.
+// writeMPI serializes a big integer to w.
 func writeMPI(w io.Writer, bitLength uint16, mpiBytes []byte) (err os.Error) {
        _, err = w.Write([]byte{byte(bitLength >> 8), byte(bitLength)})
        if err == nil {
@@ -393,3 +394,8 @@ func writeMPI(w io.Writer, bitLength uint16, mpiBytes []byte) (err os.Error) {
        }
        return
 }
+
+// writeBig serializes a *big.Int to w.
+func writeBig(w io.Writer, i *big.Int) os.Error {
+       return writeMPI(w, uint16(i.BitLen()), i.Bytes())
+}
index 6789d2abc792b30967bc999b1b0c1f6b6ec0b5bc..1a4692cd4f5d1b02607fe15d3c5f38480731f499 100644 (file)
@@ -190,3 +190,23 @@ func TestReadHeader(t *testing.T) {
                }
        }
 }
+
+func TestSerializeHeader(t *testing.T) {
+       tag := packetTypePublicKey
+       lengths := []int{0, 1, 2, 64, 192, 193, 8000, 8384, 8385, 10000}
+
+       for _, length := range lengths {
+               buf := bytes.NewBuffer(nil)
+               serializeHeader(buf, tag, length)
+               tag2, length2, _, err := readHeader(buf)
+               if err != nil {
+                       t.Errorf("length %d, err: %s", length, err)
+               }
+               if tag2 != tag {
+                       t.Errorf("length %d, tag incorrect (got %d, want %d)", length, tag2, tag)
+               }
+               if int(length2) != length {
+                       t.Errorf("length %d, length incorrect (got %d)", length, length2)
+               }
+       }
+}
index b22891755e3ea49ea2ea2a6b9385ab9df7a2e543..694482390294e25e575dabc4f09a5ee5c6968edd 100644 (file)
@@ -8,6 +8,7 @@ import (
        "big"
        "bytes"
        "crypto/cipher"
+       "crypto/dsa"
        "crypto/openpgp/error"
        "crypto/openpgp/s2k"
        "crypto/rsa"
@@ -134,7 +135,16 @@ func (pk *PrivateKey) Decrypt(passphrase []byte) os.Error {
 }
 
 func (pk *PrivateKey) parsePrivateKey(data []byte) (err os.Error) {
-       // TODO(agl): support DSA and ECDSA private keys.
+       switch pk.PublicKey.PubKeyAlgo {
+       case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoRSAEncryptOnly:
+               return pk.parseRSAPrivateKey(data)
+       case PubKeyAlgoDSA:
+               return pk.parseDSAPrivateKey(data)
+       }
+       panic("impossible")
+}
+
+func (pk *PrivateKey) parseRSAPrivateKey(data []byte) (err os.Error) {
        rsaPub := pk.PublicKey.PublicKey.(*rsa.PublicKey)
        rsaPriv := new(rsa.PrivateKey)
        rsaPriv.PublicKey = *rsaPub
@@ -162,3 +172,22 @@ func (pk *PrivateKey) parsePrivateKey(data []byte) (err os.Error) {
 
        return nil
 }
+
+func (pk *PrivateKey) parseDSAPrivateKey(data []byte) (err os.Error) {
+       dsaPub := pk.PublicKey.PublicKey.(*dsa.PublicKey)
+       dsaPriv := new(dsa.PrivateKey)
+       dsaPriv.PublicKey = *dsaPub
+
+       buf := bytes.NewBuffer(data)
+       x, _, err := readMPI(buf)
+       if err != nil {
+               return
+       }
+
+       dsaPriv.X = new(big.Int).SetBytes(x)
+       pk.PrivateKey = dsaPriv
+       pk.Encrypted = false
+       pk.encryptedData = nil
+
+       return nil
+}
index 8866bdaaa9474c280c2fd06bb7703eccd74b2a1d..ebef481fb7f19c6c398b2ba8806b845e10e3e01a 100644 (file)
@@ -11,6 +11,7 @@ import (
        "crypto/rsa"
        "crypto/sha1"
        "encoding/binary"
+       "fmt"
        "hash"
        "io"
        "os"
@@ -178,12 +179,6 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err os.E
                return error.InvalidArgumentError("public key cannot generate signatures")
        }
 
-       rsaPublicKey, ok := pk.PublicKey.(*rsa.PublicKey)
-       if !ok {
-               // TODO(agl): support DSA and ECDSA keys.
-               return error.UnsupportedError("non-RSA public key")
-       }
-
        signed.Write(sig.HashSuffix)
        hashBytes := signed.Sum()
 
@@ -191,11 +186,28 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err os.E
                return error.SignatureError("hash tag doesn't match")
        }
 
-       err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, sig.Signature)
-       if err != nil {
-               return error.SignatureError("RSA verification failure")
+       if pk.PubKeyAlgo != sig.PubKeyAlgo {
+               return error.InvalidArgumentError("public key and signature use different algorithms")
        }
-       return nil
+
+       switch pk.PubKeyAlgo {
+       case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
+               rsaPublicKey, _ := pk.PublicKey.(*rsa.PublicKey)
+               err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, sig.RSASignature)
+               if err != nil {
+                       return error.SignatureError("RSA verification failure")
+               }
+               return nil
+       case PubKeyAlgoDSA:
+               dsaPublicKey, _ := pk.PublicKey.(*dsa.PublicKey)
+               if !dsa.Verify(dsaPublicKey, hashBytes, sig.DSASigR, sig.DSASigS) {
+                       return error.SignatureError("DSA verification failure")
+               }
+               return nil
+       default:
+               panic("shouldn't happen")
+       }
+       panic("unreachable")
 }
 
 // VerifyKeySignature returns nil iff sig is a valid signature, make by this
@@ -239,9 +251,21 @@ func (pk *PublicKey) VerifyUserIdSignature(id string, sig *Signature) (err os.Er
        return pk.VerifySignature(h, sig)
 }
 
+// KeyIdString returns the public key's fingerprint in capital hex
+// (e.g. "6C7EE1B8621CC013").
+func (pk *PublicKey) KeyIdString() string {
+       return fmt.Sprintf("%X", pk.Fingerprint[12:20])
+}
+
+// KeyIdShortString returns the short form of public key's fingerprint
+// in capital hex, as shown by gpg --list-keys (e.g. "621CC013").
+func (pk *PublicKey) KeyIdShortString() string {
+       return fmt.Sprintf("%X", pk.Fingerprint[16:20])
+}
+
 // A parsedMPI is used to store the contents of a big integer, along with the
 // bit length that was specified in the original input. This allows the MPI to
-// be reserialised exactly.
+// be reserialized exactly.
 type parsedMPI struct {
        bytes     []byte
        bitLength uint16
index c015f64aec965005861863fbb45945d48a99224d..069388c14dccc1fc705e2b476faa9d29c3b2939b 100644 (file)
@@ -16,9 +16,11 @@ var pubKeyTests = []struct {
        creationTime   uint32
        pubKeyAlgo     PublicKeyAlgorithm
        keyId          uint64
+       keyIdString    string
+       keyIdShort     string
 }{
-       {rsaPkDataHex, rsaFingerprintHex, 0x4d3c5c10, PubKeyAlgoRSA, 0xa34d7e18c20c31bb},
-       {dsaPkDataHex, dsaFingerprintHex, 0x4d432f89, PubKeyAlgoDSA, 0x8e8fbe54062f19ed},
+       {rsaPkDataHex, rsaFingerprintHex, 0x4d3c5c10, PubKeyAlgoRSA, 0xa34d7e18c20c31bb, "A34D7E18C20C31BB", "C20C31BB"},
+       {dsaPkDataHex, dsaFingerprintHex, 0x4d432f89, PubKeyAlgoDSA, 0x8e8fbe54062f19ed, "8E8FBE54062F19ED", "062F19ED"},
 }
 
 func TestPublicKeyRead(t *testing.T) {
@@ -46,6 +48,12 @@ func TestPublicKeyRead(t *testing.T) {
                if pk.KeyId != test.keyId {
                        t.Errorf("#%d: bad keyid got:%x want:%x", i, pk.KeyId, test.keyId)
                }
+               if g, e := pk.KeyIdString(), test.keyIdString; g != e {
+                       t.Errorf("#%d: bad KeyIdString got:%q want:%q", i, g, e)
+               }
+               if g, e := pk.KeyIdShortString(), test.keyIdShort; g != e {
+                       t.Errorf("#%d: bad KeyIdShortString got:%q want:%q", i, g, e)
+               }
        }
 }
 
index fd2518ab41eaa92a837b5f36ec58b330eda84d10..719657e76ee09e57b633796badf5144b28c060d2 100644 (file)
@@ -5,7 +5,9 @@
 package packet
 
 import (
+       "big"
        "crypto"
+       "crypto/dsa"
        "crypto/openpgp/error"
        "crypto/openpgp/s2k"
        "crypto/rand"
@@ -29,7 +31,9 @@ type Signature struct {
        // of bad signed data.
        HashTag      [2]byte
        CreationTime uint32 // Unix epoch time
-       Signature    []byte
+
+       RSASignature     []byte
+       DSASigR, DSASigS *big.Int
 
        // The following are optional so are nil when not included in the
        // signature.
@@ -66,7 +70,7 @@ func (sig *Signature) parse(r io.Reader) (err os.Error) {
        sig.SigType = SignatureType(buf[0])
        sig.PubKeyAlgo = PublicKeyAlgorithm(buf[1])
        switch sig.PubKeyAlgo {
-       case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
+       case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA:
        default:
                err = error.UnsupportedError("public key algorithm " + strconv.Itoa(int(sig.PubKeyAlgo)))
                return
@@ -122,8 +126,20 @@ func (sig *Signature) parse(r io.Reader) (err os.Error) {
                return
        }
 
-       // We have already checked that the public key algorithm is RSA.
-       sig.Signature, _, err = readMPI(r)
+       switch sig.PubKeyAlgo {
+       case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
+               sig.RSASignature, _, err = readMPI(r)
+       case PubKeyAlgoDSA:
+               var rBytes, sBytes []byte
+               rBytes, _, err = readMPI(r)
+               sig.DSASigR = new(big.Int).SetBytes(rBytes)
+               if err == nil {
+                       sBytes, _, err = readMPI(r)
+                       sig.DSASigS = new(big.Int).SetBytes(sBytes)
+               }
+       default:
+               panic("unreachable")
+       }
        return
 }
 
@@ -316,8 +332,8 @@ func subpacketLengthLength(length int) int {
        return 5
 }
 
-// serialiseSubpacketLength marshals the given length into to.
-func serialiseSubpacketLength(to []byte, length int) int {
+// serializeSubpacketLength marshals the given length into to.
+func serializeSubpacketLength(to []byte, length int) int {
        if length < 192 {
                to[0] = byte(length)
                return 1
@@ -336,7 +352,7 @@ func serialiseSubpacketLength(to []byte, length int) int {
        return 5
 }
 
-// subpacketsLength returns the serialised length, in bytes, of the given
+// subpacketsLength returns the serialized length, in bytes, of the given
 // subpackets.
 func subpacketsLength(subpackets []outputSubpacket, hashed bool) (length int) {
        for _, subpacket := range subpackets {
@@ -349,11 +365,11 @@ func subpacketsLength(subpackets []outputSubpacket, hashed bool) (length int) {
        return
 }
 
-// serialiseSubpackets marshals the given subpackets into to.
-func serialiseSubpackets(to []byte, subpackets []outputSubpacket, hashed bool) {
+// serializeSubpackets marshals the given subpackets into to.
+func serializeSubpackets(to []byte, subpackets []outputSubpacket, hashed bool) {
        for _, subpacket := range subpackets {
                if subpacket.hashed == hashed {
-                       n := serialiseSubpacketLength(to, len(subpacket.contents)+1)
+                       n := serializeSubpacketLength(to, len(subpacket.contents)+1)
                        to[n] = byte(subpacket.subpacketType)
                        to = to[1+n:]
                        n = copy(to, subpacket.contents)
@@ -381,7 +397,7 @@ func (sig *Signature) buildHashSuffix() (err os.Error) {
        }
        sig.HashSuffix[4] = byte(hashedSubpacketsLen >> 8)
        sig.HashSuffix[5] = byte(hashedSubpacketsLen)
-       serialiseSubpackets(sig.HashSuffix[6:l], sig.outSubpackets, true)
+       serializeSubpackets(sig.HashSuffix[6:l], sig.outSubpackets, true)
        trailer := sig.HashSuffix[l:]
        trailer[0] = 4
        trailer[1] = 0xff
@@ -392,32 +408,66 @@ func (sig *Signature) buildHashSuffix() (err os.Error) {
        return
 }
 
-// SignRSA signs a message with an RSA private key. The hash, h, must contain
-// the hash of message to be signed and will be mutated by this function.
-func (sig *Signature) SignRSA(h hash.Hash, priv *rsa.PrivateKey) (err os.Error) {
+func (sig *Signature) signPrepareHash(h hash.Hash) (digest []byte, err os.Error) {
        err = sig.buildHashSuffix()
        if err != nil {
                return
        }
 
        h.Write(sig.HashSuffix)
-       digest := h.Sum()
+       digest = h.Sum()
        copy(sig.HashTag[:], digest)
-       sig.Signature, err = rsa.SignPKCS1v15(rand.Reader, priv, sig.Hash, digest)
        return
 }
 
-// Serialize marshals sig to w. SignRSA must have been called first.
+// SignRSA signs a message with an RSA private key. The hash, h, must contain
+// the hash of the message to be signed and will be mutated by this function.
+// On success, the signature is stored in sig. Call Serialize to write it out.
+func (sig *Signature) SignRSA(h hash.Hash, priv *rsa.PrivateKey) (err os.Error) {
+       digest, err := sig.signPrepareHash(h)
+       if err != nil {
+               return
+       }
+       sig.RSASignature, err = rsa.SignPKCS1v15(rand.Reader, priv, sig.Hash, digest)
+       return
+}
+
+// SignDSA signs a message with a DSA private key. The hash, h, must contain
+// the hash of the message to be signed and will be mutated by this function.
+// On success, the signature is stored in sig. Call Serialize to write it out.
+func (sig *Signature) SignDSA(h hash.Hash, priv *dsa.PrivateKey) (err os.Error) {
+       digest, err := sig.signPrepareHash(h)
+       if err != nil {
+               return
+       }
+       sig.DSASigR, sig.DSASigS, err = dsa.Sign(rand.Reader, priv, digest)
+       return
+}
+
+// Serialize marshals sig to w. SignRSA or SignDSA must have been called first.
 func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
-       if sig.Signature == nil {
-               return error.InvalidArgumentError("Signature: need to call SignRSA before Serialize")
+       if sig.RSASignature == nil && sig.DSASigR == nil {
+               return error.InvalidArgumentError("Signature: need to call SignRSA or SignDSA before Serialize")
+       }
+
+       sigLength := 0
+       switch sig.PubKeyAlgo {
+       case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
+               sigLength = len(sig.RSASignature)
+       case PubKeyAlgoDSA:
+               sigLength = 2 /* MPI length */
+               sigLength += (sig.DSASigR.BitLen() + 7) / 8
+               sigLength += 2 /* MPI length */
+               sigLength += (sig.DSASigS.BitLen() + 7) / 8
+       default:
+               panic("impossible")
        }
 
        unhashedSubpacketsLen := subpacketsLength(sig.outSubpackets, false)
        length := len(sig.HashSuffix) - 6 /* trailer not included */ +
                2 /* length of unhashed subpackets */ + unhashedSubpacketsLen +
-               2 /* hash tag */ + 2 /* length of signature MPI */ + len(sig.Signature)
-       err = serialiseHeader(w, packetTypeSignature, length)
+               2 /* hash tag */ + 2 /* length of signature MPI */ + sigLength
+       err = serializeHeader(w, packetTypeSignature, length)
        if err != nil {
                return
        }
@@ -430,7 +480,7 @@ func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
        unhashedSubpackets := make([]byte, 2+unhashedSubpacketsLen)
        unhashedSubpackets[0] = byte(unhashedSubpacketsLen >> 8)
        unhashedSubpackets[1] = byte(unhashedSubpacketsLen)
-       serialiseSubpackets(unhashedSubpackets[2:], sig.outSubpackets, false)
+       serializeSubpackets(unhashedSubpackets[2:], sig.outSubpackets, false)
 
        _, err = w.Write(unhashedSubpackets)
        if err != nil {
@@ -440,7 +490,19 @@ func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
        if err != nil {
                return
        }
-       return writeMPI(w, 8*uint16(len(sig.Signature)), sig.Signature)
+
+       switch sig.PubKeyAlgo {
+       case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
+               err = writeMPI(w, 8*uint16(len(sig.RSASignature)), sig.RSASignature)
+       case PubKeyAlgoDSA:
+               err = writeBig(w, sig.DSASigR)
+               if err == nil {
+                       err = writeBig(w, sig.DSASigS)
+               }
+       default:
+               panic("impossible")
+       }
+       return
 }
 
 // outputSubpacket represents a subpacket to be marshaled.
index 58199e1329738f0dc5b3548be3dec22d81e20367..6218d9990dd381cefadf82b2416bb348bc3be99a 100644 (file)
@@ -44,6 +44,17 @@ func TestReadPrivateKeyRing(t *testing.T) {
        }
 }
 
+func TestReadDSAKey(t *testing.T) {
+       kring, err := ReadKeyRing(readerFromHex(dsaTestKeyHex))
+       if err != nil {
+               t.Error(err)
+               return
+       }
+       if len(kring) != 1 || uint32(kring[0].PrimaryKey.KeyId) != 0x0CCC0360 {
+               t.Errorf("bad parse: %#v", kring)
+       }
+}
+
 func TestGetKeyById(t *testing.T) {
        kring, _ := ReadKeyRing(readerFromHex(testKeys1And2Hex))
 
@@ -192,7 +203,7 @@ func TestSymmetricallyEncrypted(t *testing.T) {
        }
 }
 
-func testDetachedSignature(t *testing.T, kring KeyRing, signature io.Reader, sigInput, tag string) {
+func testDetachedSignature(t *testing.T, kring KeyRing, signature io.Reader, sigInput, tag string, expectedSignerKeyId uint64) {
        signed := bytes.NewBufferString(sigInput)
        signer, err := CheckDetachedSignature(kring, signed, signature)
        if err != nil {
@@ -203,7 +214,6 @@ func testDetachedSignature(t *testing.T, kring KeyRing, signature io.Reader, sig
                t.Errorf("%s: signer is nil", tag)
                return
        }
-       expectedSignerKeyId := uint64(0xa34d7e18c20c31bb)
        if signer.PrimaryKey.KeyId != expectedSignerKeyId {
                t.Errorf("%s: wrong signer got:%x want:%x", tag, signer.PrimaryKey.KeyId, expectedSignerKeyId)
        }
@@ -211,10 +221,18 @@ func testDetachedSignature(t *testing.T, kring KeyRing, signature io.Reader, sig
 
 func TestDetachedSignature(t *testing.T) {
        kring, _ := ReadKeyRing(readerFromHex(testKeys1And2Hex))
-       testDetachedSignature(t, kring, readerFromHex(detachedSignatureHex), signedInput, "binary")
-       testDetachedSignature(t, kring, readerFromHex(detachedSignatureTextHex), signedInput, "text")
+       testDetachedSignature(t, kring, readerFromHex(detachedSignatureHex), signedInput, "binary", testKey1KeyId)
+       testDetachedSignature(t, kring, readerFromHex(detachedSignatureTextHex), signedInput, "text", testKey1KeyId)
 }
 
+func TestDetachedSignatureDSA(t *testing.T) {
+       kring, _ := ReadKeyRing(readerFromHex(dsaTestKeyHex))
+       testDetachedSignature(t, kring, readerFromHex(detachedSignatureDSAHex), signedInput, "binary", testKey3KeyId)
+}
+
+const testKey1KeyId = 0xA34D7E18C20C31BB
+const testKey3KeyId = 0x338934250CCC0360
+
 const signedInput = "Signed message\nline 2\nline 3\n"
 const signedTextInput = "Signed message\r\nline 2\r\nline 3\r\n"
 
@@ -224,6 +242,8 @@ const detachedSignatureHex = "889c04000102000605024d449cd1000a0910a34d7e18c20c31
 
 const detachedSignatureTextHex = "889c04010102000605024d449d21000a0910a34d7e18c20c31bbc8c60400a24fbef7342603a41cb1165767bd18985d015fb72fe05db42db36cfb2f1d455967f1e491194fbf6cf88146222b23bf6ffbd50d17598d976a0417d3192ff9cc0034fd00f287b02e90418bbefe609484b09231e4e7a5f3562e199bf39909ab5276c4d37382fe088f6b5c3426fc1052865da8b3ab158672d58b6264b10823dc4b39"
 
+const detachedSignatureDSAHex = "884604001102000605024d6c4eac000a0910338934250ccc0360f18d00a087d743d6405ed7b87755476629600b8b694a39e900a0abff8126f46faf1547c1743c37b21b4ea15b8f83"
+
 const testKeys1And2Hex = "988d044d3c5c10010400b1d13382944bd5aba23a4312968b5095d14f947f600eb478e14a6fcb16b0e0cac764884909c020bc495cfcc39a935387c661507bdb236a0612fb582cac3af9b29cc2c8c70090616c41b662f4da4c1201e195472eb7f4ae1ccbcbf9940fe21d985e379a5563dde5b9a23d35f1cfaa5790da3b79db26f23695107bfaca8e7b5bcd0011010001b41054657374204b6579203120285253412988b804130102002205024d3c5c10021b03060b090807030206150802090a0b0416020301021e01021780000a0910a34d7e18c20c31bbb5b304009cc45fe610b641a2c146331be94dade0a396e73ca725e1b25c21708d9cab46ecca5ccebc23055879df8f99eea39b377962a400f2ebdc36a7c99c333d74aeba346315137c3ff9d0a09b0273299090343048afb8107cf94cbd1400e3026f0ccac7ecebbc4d78588eb3e478fe2754d3ca664bcf3eac96ca4a6b0c8d7df5102f60f6b0020003b88d044d3c5c10010400b201df61d67487301f11879d514f4248ade90c8f68c7af1284c161098de4c28c2850f1ec7b8e30f959793e571542ffc6532189409cb51c3d30dad78c4ad5165eda18b20d9826d8707d0f742e2ab492103a85bbd9ddf4f5720f6de7064feb0d39ee002219765bb07bcfb8b877f47abe270ddeda4f676108cecb6b9bb2ad484a4f0011010001889f04180102000905024d3c5c10021b0c000a0910a34d7e18c20c31bb1a03040085c8d62e16d05dc4e9dad64953c8a2eed8b6c12f92b1575eeaa6dcf7be9473dd5b24b37b6dffbb4e7c99ed1bd3cb11634be19b3e6e207bed7505c7ca111ccf47cb323bf1f8851eb6360e8034cbff8dd149993c959de89f8f77f38e7e98b8e3076323aa719328e2b408db5ec0d03936efd57422ba04f925cdc7b4c1af7590e40ab0020003988d044d3c5c33010400b488c3e5f83f4d561f317817538d9d0397981e9aef1321ca68ebfae1cf8b7d388e19f4b5a24a82e2fbbf1c6c26557a6c5845307a03d815756f564ac7325b02bc83e87d5480a8fae848f07cb891f2d51ce7df83dcafdc12324517c86d472cc0ee10d47a68fd1d9ae49a6c19bbd36d82af597a0d88cc9c49de9df4e696fc1f0b5d0011010001b42754657374204b6579203220285253412c20656e637279707465642070726976617465206b65792988b804130102002205024d3c5c33021b03060b090807030206150802090a0b0416020301021e01021780000a0910d4984f961e35246b98940400908a73b6a6169f700434f076c6c79015a49bee37130eaf23aaa3cfa9ce60bfe4acaa7bc95f1146ada5867e0079babb38804891f4f0b8ebca57a86b249dee786161a755b7a342e68ccf3f78ed6440a93a6626beb9a37aa66afcd4f888790cb4bb46d94a4ae3eb3d7d3e6b00f6bfec940303e89ec5b32a1eaaacce66497d539328b0020003b88d044d3c5c33010400a4e913f9442abcc7f1804ccab27d2f787ffa592077ca935a8bb23165bd8d57576acac647cc596b2c3f814518cc8c82953c7a4478f32e0cf645630a5ba38d9618ef2bc3add69d459ae3dece5cab778938d988239f8c5ae437807075e06c828019959c644ff05ef6a5a1dab72227c98e3a040b0cf219026640698d7a13d8538a570011010001889f04180102000905024d3c5c33021b0c000a0910d4984f961e35246b26c703ff7ee29ef53bc1ae1ead533c408fa136db508434e233d6e62be621e031e5940bbd4c08142aed0f82217e7c3e1ec8de574bc06ccf3c36633be41ad78a9eacd209f861cae7b064100758545cc9dd83db71806dc1cfd5fb9ae5c7474bba0c19c44034ae61bae5eca379383339dece94ff56ff7aa44a582f3e5c38f45763af577c0934b0020003"
 
 const testKeys1And2PrivateHex = "9501d8044d3c5c10010400b1d13382944bd5aba23a4312968b5095d14f947f600eb478e14a6fcb16b0e0cac764884909c020bc495cfcc39a935387c661507bdb236a0612fb582cac3af9b29cc2c8c70090616c41b662f4da4c1201e195472eb7f4ae1ccbcbf9940fe21d985e379a5563dde5b9a23d35f1cfaa5790da3b79db26f23695107bfaca8e7b5bcd00110100010003ff4d91393b9a8e3430b14d6209df42f98dc927425b881f1209f319220841273a802a97c7bdb8b3a7740b3ab5866c4d1d308ad0d3a79bd1e883aacf1ac92dfe720285d10d08752a7efe3c609b1d00f17f2805b217be53999a7da7e493bfc3e9618fd17018991b8128aea70a05dbce30e4fbe626aa45775fa255dd9177aabf4df7cf0200c1ded12566e4bc2bb590455e5becfb2e2c9796482270a943343a7835de41080582c2be3caf5981aa838140e97afa40ad652a0b544f83eb1833b0957dce26e47b0200eacd6046741e9ce2ec5beb6fb5e6335457844fb09477f83b050a96be7da043e17f3a9523567ed40e7a521f818813a8b8a72209f1442844843ccc7eb9805442570200bdafe0438d97ac36e773c7162028d65844c4d463e2420aa2228c6e50dc2743c3d6c72d0d782a5173fe7be2169c8a9f4ef8a7cf3e37165e8c61b89c346cdc6c1799d2b41054657374204b6579203120285253412988b804130102002205024d3c5c10021b03060b090807030206150802090a0b0416020301021e01021780000a0910a34d7e18c20c31bbb5b304009cc45fe610b641a2c146331be94dade0a396e73ca725e1b25c21708d9cab46ecca5ccebc23055879df8f99eea39b377962a400f2ebdc36a7c99c333d74aeba346315137c3ff9d0a09b0273299090343048afb8107cf94cbd1400e3026f0ccac7ecebbc4d78588eb3e478fe2754d3ca664bcf3eac96ca4a6b0c8d7df5102f60f6b00200009d01d8044d3c5c10010400b201df61d67487301f11879d514f4248ade90c8f68c7af1284c161098de4c28c2850f1ec7b8e30f959793e571542ffc6532189409cb51c3d30dad78c4ad5165eda18b20d9826d8707d0f742e2ab492103a85bbd9ddf4f5720f6de7064feb0d39ee002219765bb07bcfb8b877f47abe270ddeda4f676108cecb6b9bb2ad484a4f00110100010003fd17a7490c22a79c59281fb7b20f5e6553ec0c1637ae382e8adaea295f50241037f8997cf42c1ce26417e015091451b15424b2c59eb8d4161b0975630408e394d3b00f88d4b4e18e2cc85e8251d4753a27c639c83f5ad4a571c4f19d7cd460b9b73c25ade730c99df09637bd173d8e3e981ac64432078263bb6dc30d3e974150dd0200d0ee05be3d4604d2146fb0457f31ba17c057560785aa804e8ca5530a7cd81d3440d0f4ba6851efcfd3954b7e68908fc0ba47f7ac37bf559c6c168b70d3a7c8cd0200da1c677c4bce06a068070f2b3733b0a714e88d62aa3f9a26c6f5216d48d5c2b5624144f3807c0df30be66b3268eeeca4df1fbded58faf49fc95dc3c35f134f8b01fd1396b6c0fc1b6c4f0eb8f5e44b8eace1e6073e20d0b8bc5385f86f1cf3f050f66af789f3ef1fc107b7f4421e19e0349c730c68f0a226981f4e889054fdb4dc149e8e889f04180102000905024d3c5c10021b0c000a0910a34d7e18c20c31bb1a03040085c8d62e16d05dc4e9dad64953c8a2eed8b6c12f92b1575eeaa6dcf7be9473dd5b24b37b6dffbb4e7c99ed1bd3cb11634be19b3e6e207bed7505c7ca111ccf47cb323bf1f8851eb6360e8034cbff8dd149993c959de89f8f77f38e7e98b8e3076323aa719328e2b408db5ec0d03936efd57422ba04f925cdc7b4c1af7590e40ab00200009501fe044d3c5c33010400b488c3e5f83f4d561f317817538d9d0397981e9aef1321ca68ebfae1cf8b7d388e19f4b5a24a82e2fbbf1c6c26557a6c5845307a03d815756f564ac7325b02bc83e87d5480a8fae848f07cb891f2d51ce7df83dcafdc12324517c86d472cc0ee10d47a68fd1d9ae49a6c19bbd36d82af597a0d88cc9c49de9df4e696fc1f0b5d0011010001fe030302e9030f3c783e14856063f16938530e148bc57a7aa3f3e4f90df9dceccdc779bc0835e1ad3d006e4a8d7b36d08b8e0de5a0d947254ecfbd22037e6572b426bcfdc517796b224b0036ff90bc574b5509bede85512f2eefb520fb4b02aa523ba739bff424a6fe81c5041f253f8d757e69a503d3563a104d0d49e9e890b9d0c26f96b55b743883b472caa7050c4acfd4a21f875bdf1258d88bd61224d303dc9df77f743137d51e6d5246b88c406780528fd9a3e15bab5452e5b93970d9dcc79f48b38651b9f15bfbcf6da452837e9cc70683d1bdca94507870f743e4ad902005812488dd342f836e72869afd00ce1850eea4cfa53ce10e3608e13d3c149394ee3cbd0e23d018fcbcb6e2ec5a1a22972d1d462ca05355d0d290dd2751e550d5efb38c6c89686344df64852bf4ff86638708f644e8ec6bd4af9b50d8541cb91891a431326ab2e332faa7ae86cfb6e0540aa63160c1e5cdd5a4add518b303fff0a20117c6bc77f7cfbaf36b04c865c6c2b42754657374204b6579203220285253412c20656e637279707465642070726976617465206b65792988b804130102002205024d3c5c33021b03060b090807030206150802090a0b0416020301021e01021780000a0910d4984f961e35246b98940400908a73b6a6169f700434f076c6c79015a49bee37130eaf23aaa3cfa9ce60bfe4acaa7bc95f1146ada5867e0079babb38804891f4f0b8ebca57a86b249dee786161a755b7a342e68ccf3f78ed6440a93a6626beb9a37aa66afcd4f888790cb4bb46d94a4ae3eb3d7d3e6b00f6bfec940303e89ec5b32a1eaaacce66497d539328b00200009d01fe044d3c5c33010400a4e913f9442abcc7f1804ccab27d2f787ffa592077ca935a8bb23165bd8d57576acac647cc596b2c3f814518cc8c82953c7a4478f32e0cf645630a5ba38d9618ef2bc3add69d459ae3dece5cab778938d988239f8c5ae437807075e06c828019959c644ff05ef6a5a1dab72227c98e3a040b0cf219026640698d7a13d8538a570011010001fe030302e9030f3c783e148560f936097339ae381d63116efcf802ff8b1c9360767db5219cc987375702a4123fd8657d3e22700f23f95020d1b261eda5257e9a72f9a918e8ef22dd5b3323ae03bbc1923dd224db988cadc16acc04b120a9f8b7e84da9716c53e0334d7b66586ddb9014df604b41be1e960dcfcbc96f4ed150a1a0dd070b9eb14276b9b6be413a769a75b519a53d3ecc0c220e85cd91ca354d57e7344517e64b43b6e29823cbd87eae26e2b2e78e6dedfbb76e3e9f77bcb844f9a8932eb3db2c3f9e44316e6f5d60e9e2a56e46b72abe6b06dc9a31cc63f10023d1f5e12d2a3ee93b675c96f504af0001220991c88db759e231b3320dcedf814dcf723fd9857e3d72d66a0f2af26950b915abdf56c1596f46a325bf17ad4810d3535fb02a259b247ac3dbd4cc3ecf9c51b6c07cebb009c1506fba0a89321ec8683e3fd009a6e551d50243e2d5092fefb3321083a4bad91320dc624bd6b5dddf93553e3d53924c05bfebec1fb4bd47e89a1a889f04180102000905024d3c5c33021b0c000a0910d4984f961e35246b26c703ff7ee29ef53bc1ae1ead533c408fa136db508434e233d6e62be621e031e5940bbd4c08142aed0f82217e7c3e1ec8de574bc06ccf3c36633be41ad78a9eacd209f861cae7b064100758545cc9dd83db71806dc1cfd5fb9ae5c7474bba0c19c44034ae61bae5eca379383339dece94ff56ff7aa44a582f3e5c38f45763af577c0934b0020000"
@@ -235,3 +255,7 @@ const signedTextMessageHex = "a3019bc0cbccc8c4b8d8b74ee2108fe16ec6d36a250cbece0c
 const signedEncryptedMessageHex = "848c032a67d68660df41c70103ff5789d0de26b6a50c985a02a13131ca829c413a35d0e6fa8d6842599252162808ac7439c72151c8c6183e76923fe3299301414d0c25a2f06a2257db3839e7df0ec964773f6e4c4ac7ff3b48c444237166dd46ba8ff443a5410dc670cb486672fdbe7c9dfafb75b4fea83af3a204fe2a7dfa86bd20122b4f3d2646cbeecb8f7be8d2c03b018bd210b1d3791e1aba74b0f1034e122ab72e760492c192383cf5e20b5628bd043272d63df9b923f147eb6091cd897553204832aba48fec54aa447547bb16305a1024713b90e77fd0065f1918271947549205af3c74891af22ee0b56cd29bfec6d6e351901cd4ab3ece7c486f1e32a792d4e474aed98ee84b3f591c7dff37b64e0ecd68fd036d517e412dcadf85840ce184ad7921ad446c4ee28db80447aea1ca8d4f574db4d4e37688158ddd19e14ee2eab4873d46947d65d14a23e788d912cf9a19624ca7352469b72a83866b7c23cb5ace3deab3c7018061b0ba0f39ed2befe27163e5083cf9b8271e3e3d52cc7ad6e2a3bd81d4c3d7022f8d"
 
 const symmetricallyEncryptedCompressedHex = "8c0d04030302eb4a03808145d0d260c92f714339e13de5a79881216431925bf67ee2898ea61815f07894cd0703c50d0a76ef64d482196f47a8bc729af9b80bb6"
+
+const dsaTestKeyHex = "9901a2044d6c49de110400cb5ce438cf9250907ac2ba5bf6547931270b89f7c4b53d9d09f4d0213a5ef2ec1f26806d3d259960f872a4a102ef1581ea3f6d6882d15134f21ef6a84de933cc34c47cc9106efe3bd84c6aec12e78523661e29bc1a61f0aab17fa58a627fd5fd33f5149153fbe8cd70edf3d963bc287ef875270ff14b5bfdd1bca4483793923b00a0fe46d76cb6e4cbdc568435cd5480af3266d610d303fe33ae8273f30a96d4d34f42fa28ce1112d425b2e3bf7ea553d526e2db6b9255e9dc7419045ce817214d1a0056dbc8d5289956a4b1b69f20f1105124096e6a438f41f2e2495923b0f34b70642607d45559595c7fe94d7fa85fc41bf7d68c1fd509ebeaa5f315f6059a446b9369c277597e4f474a9591535354c7e7f4fd98a08aa60400b130c24ff20bdfbf683313f5daebf1c9b34b3bdadfc77f2ddd72ee1fb17e56c473664bc21d66467655dd74b9005e3a2bacce446f1920cd7017231ae447b67036c9b431b8179deacd5120262d894c26bc015bffe3d827ba7087ad9b700d2ca1f6d16cc1786581e5dd065f293c31209300f9b0afcc3f7c08dd26d0a22d87580b4db41054657374204b65792033202844534129886204131102002205024d6c49de021b03060b090807030206150802090a0b0416020301021e01021780000a0910338934250ccc03607e0400a0bdb9193e8a6b96fc2dfc108ae848914b504481f100a09c4dc148cb693293a67af24dd40d2b13a9e36794"
+
+const dsaTestKeyPrivateHex = "9501bb044d6c49de110400cb5ce438cf9250907ac2ba5bf6547931270b89f7c4b53d9d09f4d0213a5ef2ec1f26806d3d259960f872a4a102ef1581ea3f6d6882d15134f21ef6a84de933cc34c47cc9106efe3bd84c6aec12e78523661e29bc1a61f0aab17fa58a627fd5fd33f5149153fbe8cd70edf3d963bc287ef875270ff14b5bfdd1bca4483793923b00a0fe46d76cb6e4cbdc568435cd5480af3266d610d303fe33ae8273f30a96d4d34f42fa28ce1112d425b2e3bf7ea553d526e2db6b9255e9dc7419045ce817214d1a0056dbc8d5289956a4b1b69f20f1105124096e6a438f41f2e2495923b0f34b70642607d45559595c7fe94d7fa85fc41bf7d68c1fd509ebeaa5f315f6059a446b9369c277597e4f474a9591535354c7e7f4fd98a08aa60400b130c24ff20bdfbf683313f5daebf1c9b34b3bdadfc77f2ddd72ee1fb17e56c473664bc21d66467655dd74b9005e3a2bacce446f1920cd7017231ae447b67036c9b431b8179deacd5120262d894c26bc015bffe3d827ba7087ad9b700d2ca1f6d16cc1786581e5dd065f293c31209300f9b0afcc3f7c08dd26d0a22d87580b4d00009f592e0619d823953577d4503061706843317e4fee083db41054657374204b65792033202844534129886204131102002205024d6c49de021b03060b090807030206150802090a0b0416020301021e01021780000a0910338934250ccc03607e0400a0bdb9193e8a6b96fc2dfc108ae848914b504481f100a09c4dc148cb693293a67af24dd40d2b13a9e36794"
index 1a2e2bf040782b516ed7d1028f7745a19fe5e0de..ef7b11230a987ac55f230504d6da617e7bcaa748 100644 (file)
@@ -6,6 +6,7 @@ package openpgp
 
 import (
        "crypto"
+       "crypto/dsa"
        "crypto/openpgp/armor"
        "crypto/openpgp/error"
        "crypto/openpgp/packet"
@@ -39,7 +40,7 @@ func DetachSignText(w io.Writer, signer *Entity, message io.Reader) os.Error {
 // ArmoredDetachSignText signs message (after canonicalising the line endings)
 // with the private key from signer (which must already have been decrypted)
 // and writes an armored signature to w.
-func SignTextDetachedArmored(w io.Writer, signer *Entity, message io.Reader) os.Error {
+func ArmoredDetachSignText(w io.Writer, signer *Entity, message io.Reader) os.Error {
        return armoredDetachSign(w, signer, message, packet.SigTypeText)
 }
 
@@ -80,6 +81,9 @@ func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.S
        case packet.PubKeyAlgoRSA, packet.PubKeyAlgoRSASignOnly:
                priv := signer.PrivateKey.PrivateKey.(*rsa.PrivateKey)
                err = sig.SignRSA(h, priv)
+       case packet.PubKeyAlgoDSA:
+               priv := signer.PrivateKey.PrivateKey.(*dsa.PrivateKey)
+               err = sig.SignDSA(h, priv)
        default:
                err = error.UnsupportedError("public key algorithm: " + strconv.Itoa(int(sig.PubKeyAlgo)))
        }
index 33e8809f224bc84f531c33c185e5b4e917e85c08..42cd0d27f850f512bc153ec7fbebc1d2d0aa314b 100644 (file)
@@ -18,7 +18,7 @@ func TestSignDetached(t *testing.T) {
                t.Error(err)
        }
 
-       testDetachedSignature(t, kring, out, signedInput, "check")
+       testDetachedSignature(t, kring, out, signedInput, "check", testKey1KeyId)
 }
 
 func TestSignTextDetached(t *testing.T) {
@@ -30,5 +30,17 @@ func TestSignTextDetached(t *testing.T) {
                t.Error(err)
        }
 
-       testDetachedSignature(t, kring, out, signedInput, "check")
+       testDetachedSignature(t, kring, out, signedInput, "check", testKey1KeyId)
+}
+
+func TestSignDetachedDSA(t *testing.T) {
+       kring, _ := ReadKeyRing(readerFromHex(dsaTestKeyPrivateHex))
+       out := bytes.NewBuffer(nil)
+       message := bytes.NewBufferString(signedInput)
+       err := DetachSign(out, kring[0], message)
+       if err != nil {
+               t.Error(err)
+       }
+
+       testDetachedSignature(t, kring, out, signedInput, "check", testKey3KeyId)
 }
index 7135f3d0f716853fe89c9364ac5695eb0d956f06..81b5a07446ea644b0e4cb5047fe8b23e343cdca4 100644 (file)
@@ -7,6 +7,7 @@ package tls
 import (
        "crypto/rand"
        "crypto/rsa"
+       "crypto/x509"
        "io"
        "io/ioutil"
        "sync"
@@ -95,6 +96,9 @@ type ConnectionState struct {
        HandshakeComplete  bool
        CipherSuite        uint16
        NegotiatedProtocol string
+
+       // the certificate chain that was presented by the other side
+       PeerCertificates []*x509.Certificate
 }
 
 // A Config structure is used to configure a TLS client or server. After one
index d203e8d5169afa1055937886d55da3c46c19f82b..1e6fe60aec2a4ba630bce7f1967fa2ae81f975e9 100644 (file)
@@ -762,6 +762,7 @@ func (c *Conn) ConnectionState() ConnectionState {
        if c.handshakeComplete {
                state.NegotiatedProtocol = c.clientProtocol
                state.CipherSuite = c.cipherSuite
+               state.PeerCertificates = c.peerCertificates
        }
 
        return state
@@ -776,15 +777,6 @@ func (c *Conn) OCSPResponse() []byte {
        return c.ocspResponse
 }
 
-// PeerCertificates returns the certificate chain that was presented by the
-// other side.
-func (c *Conn) PeerCertificates() []*x509.Certificate {
-       c.handshakeMutex.Lock()
-       defer c.handshakeMutex.Unlock()
-
-       return c.peerCertificates
-}
-
 // VerifyHostname checks that the peer certificate chain is valid for
 // connecting to host.  If so, it returns nil; if not, it returns an os.Error
 // describing the problem.
index 3e0c6393893c0518fe86c35181093ec39aaec8f4..ee77f949fc038b4d01d3d053a0e775c98b5b8ede 100644 (file)
@@ -25,7 +25,7 @@ func main() {
 
        priv, err := rsa.GenerateKey(rand.Reader, 1024)
        if err != nil {
-               log.Exitf("failed to generate private key: %s", err)
+               log.Fatalf("failed to generate private key: %s", err)
                return
        }
 
@@ -46,13 +46,13 @@ func main() {
 
        derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
        if err != nil {
-               log.Exitf("Failed to create certificate: %s", err)
+               log.Fatalf("Failed to create certificate: %s", err)
                return
        }
 
        certOut, err := os.Open("cert.pem", os.O_WRONLY|os.O_CREAT, 0644)
        if err != nil {
-               log.Exitf("failed to open cert.pem for writing: %s", err)
+               log.Fatalf("failed to open cert.pem for writing: %s", err)
                return
        }
        pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
index 7caf3a21a4f9147872f16da8194295b3bfc3a518..49f0a5361fbf5410ae0e0e25414740e8ae3529a9 100644 (file)
@@ -12,6 +12,6 @@ func Attach(pid int) (Process, os.Error) {
        return nil, os.NewError("debug/proc not implemented on OS X")
 }
 
-func ForkExec(argv0 string, argv []string, envv []string, dir string, fd []*os.File) (Process, os.Error) {
+func StartProcess(argv0 string, argv []string, attr *os.ProcAttr) (Process, os.Error) {
        return Attach(0)
 }
index f6474ce80ca99ba65a14597dc91337dd284cc34f..4df07c365afe5ac9a4d712f2cc02c1a728113ad2 100644 (file)
@@ -12,6 +12,6 @@ func Attach(pid int) (Process, os.Error) {
        return nil, os.NewError("debug/proc not implemented on FreeBSD")
 }
 
-func ForkExec(argv0 string, argv []string, envv []string, dir string, fd []*os.File) (Process, os.Error) {
+func StartProcess(argv0 string, argv []string, attr *os.ProcAttr) (Process, os.Error) {
        return Attach(0)
 }
index f0cc43a108eeea265c8872fdc341a3e6ba5868db..6890a2221ef4b3496c6d7b3f32963afac8b906f6 100644 (file)
@@ -1279,25 +1279,31 @@ func Attach(pid int) (Process, os.Error) {
        return p, nil
 }
 
-// ForkExec forks the current process and execs argv0, stopping the
-// new process after the exec syscall.  See os.ForkExec for additional
+// StartProcess forks the current process and execs argv0, stopping the
+// new process after the exec syscall.  See os.StartProcess for additional
 // details.
-func ForkExec(argv0 string, argv []string, envv []string, dir string, fd []*os.File) (Process, os.Error) {
+func StartProcess(argv0 string, argv []string, attr *os.ProcAttr) (Process, os.Error) {
+       sysattr := &syscall.ProcAttr{
+               Dir:    attr.Dir,
+               Env:    attr.Env,
+               Ptrace: true,
+       }
        p := newProcess(-1)
 
        // Create array of integer (system) fds.
-       intfd := make([]int, len(fd))
-       for i, f := range fd {
+       intfd := make([]int, len(attr.Files))
+       for i, f := range attr.Files {
                if f == nil {
                        intfd[i] = -1
                } else {
                        intfd[i] = f.Fd()
                }
        }
+       sysattr.Files = intfd
 
        // Fork from the monitor thread so we get the right tracer pid.
        err := p.do(func() os.Error {
-               pid, errno := syscall.PtraceForkExec(argv0, argv, envv, dir, intfd)
+               pid, _, errno := syscall.StartProcess(argv0, argv, sysattr)
                if errno != 0 {
                        return &os.PathError{"fork/exec", argv0, os.Errno(errno)}
                }
index dc22faef81bd10825424660016a9b647dcd391c5..661474b67aaaabe246ab0a2d0be3fa842580a4f9 100644 (file)
@@ -12,6 +12,6 @@ func Attach(pid int) (Process, os.Error) {
        return nil, os.NewError("debug/proc not implemented on windows")
 }
 
-func ForkExec(argv0 string, argv []string, envv []string, dir string, fd []*os.File) (Process, os.Error) {
+func StartProcess(argv0 string, argv []string, attr *os.ProcAttr) (Process, os.Error) {
        return Attach(0)
 }
index 80f6f3c7dd489d28535e026501c2b4852c26ff35..44e3b65bec87f24c18e6d5733785a657c41d8809 100644 (file)
@@ -75,17 +75,19 @@ func modeToFiles(mode, fd int) (*os.File, *os.File, os.Error) {
 
 // Run starts the named binary running with
 // arguments argv and environment envv.
+// If the dir argument is not empty, the child changes
+// into the directory before executing the binary.
 // It returns a pointer to a new Cmd representing
 // the command or an error.
 //
-// The parameters stdin, stdout, and stderr
+// The arguments stdin, stdout, and stderr
 // specify how to handle standard input, output, and error.
 // The choices are DevNull (connect to /dev/null),
 // PassThrough (connect to the current process's standard stream),
 // Pipe (connect to an operating system pipe), and
 // MergeWithStdout (only for standard error; use the same
 // file descriptor as was used for standard output).
-// If a parameter is Pipe, then the corresponding field (Stdin, Stdout, Stderr)
+// If an argument is Pipe, then the corresponding field (Stdin, Stdout, Stderr)
 // of the returned Cmd is the other end of the pipe.
 // Otherwise the field in Cmd is nil.
 func Run(name string, argv, envv []string, dir string, stdin, stdout, stderr int) (c *Cmd, err os.Error) {
@@ -105,7 +107,7 @@ func Run(name string, argv, envv []string, dir string, stdin, stdout, stderr int
        }
 
        // Run command.
-       c.Process, err = os.StartProcess(name, argv, envv, dir, fd[0:])
+       c.Process, err = os.StartProcess(name, argv, &os.ProcAttr{Dir: dir, Files: fd[:], Env: envv})
        if err != nil {
                goto Error
        }
index 3a3d3b1a53ed05c706eba2926e6dd7c76ce12df7..5e37b99eeca0f8cd4f2749aa7f3847fe41b1bee4 100644 (file)
@@ -118,3 +118,55 @@ func TestAddEnvVar(t *testing.T) {
                t.Fatal("close:", err)
        }
 }
+
+var tryargs = []string{
+       `2`,
+       `2 `,
+       "2 \t",
+       `2" "`,
+       `2 ab `,
+       `2 "ab" `,
+       `2 \ `,
+       `2 \\ `,
+       `2 \" `,
+       `2 \`,
+       `2\`,
+       `2"`,
+       `2\"`,
+       `2 "`,
+       `2 \"`,
+       ``,
+       `2 ^ `,
+       `2 \^`,
+}
+
+func TestArgs(t *testing.T) {
+       for _, a := range tryargs {
+               argv := []string{
+                       "awk",
+                       `BEGIN{printf("%s|%s|%s",ARGV[1],ARGV[2],ARGV[3])}`,
+                       "/dev/null",
+                       a,
+                       "EOF",
+               }
+               exe, err := LookPath(argv[0])
+               if err != nil {
+                       t.Fatal("run:", err)
+               }
+               cmd, err := Run(exe, argv, nil, "", DevNull, Pipe, DevNull)
+               if err != nil {
+                       t.Fatal("run:", err)
+               }
+               buf, err := ioutil.ReadAll(cmd.Stdout)
+               if err != nil {
+                       t.Fatal("read:", err)
+               }
+               expect := "/dev/null|" + a + "|EOF"
+               if string(buf) != expect {
+                       t.Errorf("read: got %q expect %q", buf, expect)
+               }
+               if err = cmd.Close(); err != nil {
+                       t.Fatal("close:", err)
+               }
+       }
+}
index 5c5d4338a1d87598b08ceaa2da5891f4531052d0..f6b7c1cda941ac4ef4798a7cbefb9f9f42064d63 100644 (file)
@@ -287,9 +287,6 @@ func (a *stmtCompiler) compile(s ast.Stmt) {
        case *ast.SwitchStmt:
                a.compileSwitchStmt(s)
 
-       case *ast.TypeCaseClause:
-               notimpl = true
-
        case *ast.TypeSwitchStmt:
                notimpl = true
 
@@ -1012,13 +1009,13 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) {
                        a.diagAt(clause.Pos(), "switch statement must contain case clauses")
                        continue
                }
-               if clause.Values == nil {
+               if clause.List == nil {
                        if hasDefault {
                                a.diagAt(clause.Pos(), "switch statement contains more than one default case")
                        }
                        hasDefault = true
                } else {
-                       ncases += len(clause.Values)
+                       ncases += len(clause.List)
                }
        }
 
@@ -1030,7 +1027,7 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) {
                if !ok {
                        continue
                }
-               for _, v := range clause.Values {
+               for _, v := range clause.List {
                        e := condbc.compileExpr(condbc.block, false, v)
                        switch {
                        case e == nil:
@@ -1077,8 +1074,8 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) {
 
                // Save jump PC's
                pc := a.nextPC()
-               if clause.Values != nil {
-                       for _ = range clause.Values {
+               if clause.List != nil {
+                       for _ = range clause.List {
                                casePCs[i] = &pc
                                i++
                        }
index 4a883ef5ee79b2460f8bee76f5a613cb6b4db055..a8a3e16204121e4db0e58d5d2c85b5a05d801893 100644 (file)
@@ -27,7 +27,7 @@ var stmtTests = []test{
        CErr("i, u := 1, 2", atLeastOneDecl),
        Val2("i, x := 2, f", "i", 2, "x", 1.0),
        // Various errors
-       CErr("1 := 2", "left side of := must be a name"),
+       CErr("1 := 2", "expected identifier"),
        CErr("c, a := 1, 1", "cannot assign"),
        // Unpacking
        Val2("x, y := oneTwo()", "x", 1, "y", 2),
index 4f67032d0c3085a029149d06037c4e33543f1f36..9920ff6b883a1c354a0abd41fb5ea52df475e864 100644 (file)
@@ -160,7 +160,7 @@ func cmdLoad(args []byte) os.Error {
                } else {
                        fname = parts[0]
                }
-               tproc, err = proc.ForkExec(fname, parts, os.Environ(), "", []*os.File{os.Stdin, os.Stdout, os.Stderr})
+               tproc, err = proc.StartProcess(fname, parts, &os.ProcAttr{Files: []*os.File{os.Stdin, os.Stdout, os.Stderr}})
                if err != nil {
                        return err
                }
index b1f0f6c1b81edd332acc24ae2dad371fe8fe610d..ed6cff78db40fbf9eda4ae796f9de1c279977aa7 100644 (file)
@@ -269,7 +269,7 @@ func Iter() <-chan KeyValue {
 }
 
 func expvarHandler(w http.ResponseWriter, r *http.Request) {
-       w.SetHeader("content-type", "application/json; charset=utf-8")
+       w.Header().Set("Content-Type", "application/json; charset=utf-8")
        fmt.Fprintf(w, "{\n")
        first := true
        for name, value := range vars {
index be972057ed74c08720e1bc6dd727fcd613223d58..14f4d522c6a4ea23c70d12808c03834cc175e830 100644 (file)
@@ -56,7 +56,7 @@
 
                flag.Bool(...)  // global options
                flag.Parse()  // parse leading command
-               subcmd := flag.Args(0)
+               subcmd := flag.Arg[0]
                switch subcmd {
                        // add per-subcommand options
                }
@@ -68,6 +68,7 @@ package flag
 import (
        "fmt"
        "os"
+       "sort"
        "strconv"
 )
 
@@ -205,16 +206,34 @@ type allFlags struct {
 
 var flags *allFlags
 
-// VisitAll visits the flags, calling fn for each. It visits all flags, even those not set.
+// sortFlags returns the flags as a slice in lexicographical sorted order.
+func sortFlags(flags map[string]*Flag) []*Flag {
+       list := make(sort.StringArray, len(flags))
+       i := 0
+       for _, f := range flags {
+               list[i] = f.Name
+               i++
+       }
+       list.Sort()
+       result := make([]*Flag, len(list))
+       for i, name := range list {
+               result[i] = flags[name]
+       }
+       return result
+}
+
+// VisitAll visits the flags in lexicographical order, calling fn for each.
+// It visits all flags, even those not set.
 func VisitAll(fn func(*Flag)) {
-       for _, f := range flags.formal {
+       for _, f := range sortFlags(flags.formal) {
                fn(f)
        }
 }
 
-// Visit visits the flags, calling fn for each. It visits only those flags that have been set.
+// Visit visits the flags in lexicographical order, calling fn for each.
+// It visits only those flags that have been set.
 func Visit(fn func(*Flag)) {
-       for _, f := range flags.actual {
+       for _, f := range sortFlags(flags.actual) {
                fn(f)
        }
 }
@@ -260,7 +279,9 @@ var Usage = func() {
 
 var panicOnError = false
 
-func fail() {
+// failf prints to standard error a formatted error and Usage, and then exits the program.
+func failf(format string, a ...interface{}) {
+       fmt.Fprintf(os.Stderr, format, a...)
        Usage()
        if panicOnError {
                panic("flag parse error")
@@ -268,6 +289,7 @@ func fail() {
        os.Exit(2)
 }
 
+// NFlag returns the number of flags that have been set.
 func NFlag() int { return len(flags.actual) }
 
 // Arg returns the i'th command-line argument.  Arg(0) is the first remaining argument
@@ -415,8 +437,7 @@ func (f *allFlags) parseOne() (ok bool) {
        }
        name := s[num_minuses:]
        if len(name) == 0 || name[0] == '-' || name[0] == '=' {
-               fmt.Fprintln(os.Stderr, "bad flag syntax:", s)
-               fail()
+               failf("bad flag syntax: %s\n", s)
        }
 
        // it's a flag. does it have an argument?
@@ -434,14 +455,12 @@ func (f *allFlags) parseOne() (ok bool) {
        m := flags.formal
        flag, alreadythere := m[name] // BUG
        if !alreadythere {
-               fmt.Fprintf(os.Stderr, "flag provided but not defined: -%s\n", name)
-               fail()
+               failf("flag provided but not defined: -%s\n", name)
        }
        if fv, ok := flag.Value.(*boolValue); ok { // special case: doesn't need an arg
                if has_value {
                        if !fv.Set(value) {
-                               fmt.Fprintf(os.Stderr, "invalid boolean value %q for flag: -%s\n", value, name)
-                               fail()
+                               failf("invalid boolean value %q for flag: -%s\n", value, name)
                        }
                } else {
                        fv.Set("true")
@@ -454,13 +473,11 @@ func (f *allFlags) parseOne() (ok bool) {
                        value, f.args = f.args[0], f.args[1:]
                }
                if !has_value {
-                       fmt.Fprintf(os.Stderr, "flag needs an argument: -%s\n", name)
-                       fail()
+                       failf("flag needs an argument: -%s\n", name)
                }
                ok = flag.Value.Set(value)
                if !ok {
-                       fmt.Fprintf(os.Stderr, "invalid value %q for flag: -%s\n", value, name)
-                       fail()
+                       failf("invalid value %q for flag: -%s\n", value, name)
                }
        }
        flags.actual[name] = flag
index 30a21e61ae2d46d91be9f5e887c0d633edd46505..1e47d12e48a76ac07c151da89426e267ae3911d5 100644 (file)
@@ -8,6 +8,7 @@ import (
        . "flag"
        "fmt"
        "os"
+       "sort"
        "testing"
 )
 
@@ -77,6 +78,12 @@ func TestEverything(t *testing.T) {
                        t.Log(k, *v)
                }
        }
+       // Now test they're visited in sort order.
+       var flagNames []string
+       Visit(func(f *Flag) { flagNames = append(flagNames, f.Name) })
+       if !sort.StringsAreSorted(flagNames) {
+               t.Errorf("flag names not sorted: %v", flagNames)
+       }
 }
 
 func TestUsage(t *testing.T) {
index 86057bf693cda2d49cd8ce44fd41407ac357045a..caaa7ac1a8af36ed2fb0ac13ef69911f33971954 100644 (file)
@@ -107,7 +107,7 @@ func (f *fmt) writePadding(n int, padding []byte) {
 }
 
 // Append b to f.buf, padded on left (w > 0) or right (w < 0 or f.minus)
-// clear flags aftewards.
+// clear flags afterwards.
 func (f *fmt) pad(b []byte) {
        var padding []byte
        var left, right int
@@ -124,7 +124,7 @@ func (f *fmt) pad(b []byte) {
 }
 
 // append s to buf, padded on left (w > 0) or right (w < 0 or f.minus).
-// clear flags aftewards.
+// clear flags afterwards.
 func (f *fmt) padString(s string) {
        var padding []byte
        var left, right int
index c0f2bacb69be934e19d14fccfb5507820b229ae0..36271a8d4665ca404d0c698c82cade78b0622065 100644 (file)
@@ -35,10 +35,15 @@ type ScanState interface {
        ReadRune() (rune int, size int, err os.Error)
        // UnreadRune causes the next call to ReadRune to return the same rune.
        UnreadRune() os.Error
-       // Token returns the next space-delimited token from the input. If
-       // a width has been specified, the returned token will be no longer
-       // than the width.
-       Token() (token string, err os.Error)
+       // Token skips space in the input if skipSpace is true, then returns the
+       // run of Unicode code points c satisfying f(c).  If f is nil,
+       // !unicode.IsSpace(c) is used; that is, the token will hold non-space
+       // characters.  Newlines are treated as space unless the scan operation
+       // is Scanln, Fscanln or Sscanln, in which case a newline is treated as
+       // EOF.  The returned slice points to shared data that may be overwritten
+       // by the next call to Token, a call to a Scan function using the ScanState
+       // as input, or when the calling Scan method returns.
+       Token(skipSpace bool, f func(int) bool) (token []byte, err os.Error)
        // Width returns the value of the width option and whether it has been set.
        // The unit is Unicode code points.
        Width() (wid int, ok bool)
@@ -134,7 +139,7 @@ type scanError struct {
        err os.Error
 }
 
-const EOF = -1
+const eof = -1
 
 // ss is the internal implementation of ScanState.
 type ss struct {
@@ -202,7 +207,7 @@ func (s *ss) getRune() (rune int) {
        rune, _, err := s.ReadRune()
        if err != nil {
                if err == os.EOF {
-                       return EOF
+                       return eof
                }
                s.error(err)
        }
@@ -214,7 +219,7 @@ func (s *ss) getRune() (rune int) {
 // syntax error.
 func (s *ss) mustReadRune() (rune int) {
        rune = s.getRune()
-       if rune == EOF {
+       if rune == eof {
                s.error(io.ErrUnexpectedEOF)
        }
        return
@@ -238,7 +243,7 @@ func (s *ss) errorString(err string) {
        panic(scanError{os.ErrorString(err)})
 }
 
-func (s *ss) Token() (tok string, err os.Error) {
+func (s *ss) Token(skipSpace bool, f func(int) bool) (tok []byte, err os.Error) {
        defer func() {
                if e := recover(); e != nil {
                        if se, ok := e.(scanError); ok {
@@ -248,10 +253,19 @@ func (s *ss) Token() (tok string, err os.Error) {
                        }
                }
        }()
-       tok = s.token()
+       if f == nil {
+               f = notSpace
+       }
+       s.buf.Reset()
+       tok = s.token(skipSpace, f)
        return
 }
 
+// notSpace is the default scanning function used in Token.
+func notSpace(r int) bool {
+       return !unicode.IsSpace(r)
+}
+
 // readRune is a structure to enable reading UTF-8 encoded code points
 // from an io.Reader.  It is used if the Reader given to the scanner does
 // not already implement io.RuneReader.
@@ -364,7 +378,7 @@ func (s *ss) free(old ssave) {
 func (s *ss) skipSpace(stopAtNewline bool) {
        for {
                rune := s.getRune()
-               if rune == EOF {
+               if rune == eof {
                        return
                }
                if rune == '\n' {
@@ -384,24 +398,27 @@ func (s *ss) skipSpace(stopAtNewline bool) {
        }
 }
 
+
 // token returns the next space-delimited string from the input.  It
 // skips white space.  For Scanln, it stops at newlines.  For Scan,
 // newlines are treated as spaces.
-func (s *ss) token() string {
-       s.skipSpace(false)
+func (s *ss) token(skipSpace bool, f func(int) bool) []byte {
+       if skipSpace {
+               s.skipSpace(false)
+       }
        // read until white space or newline
        for {
                rune := s.getRune()
-               if rune == EOF {
+               if rune == eof {
                        break
                }
-               if unicode.IsSpace(rune) {
+               if !f(rune) {
                        s.UnreadRune()
                        break
                }
                s.buf.WriteRune(rune)
        }
-       return s.buf.String()
+       return s.buf.Bytes()
 }
 
 // typeError indicates that the type of the operand did not match the format
@@ -416,7 +433,7 @@ var boolError = os.ErrorString("syntax error scanning boolean")
 // If accept is true, it puts the character into the input token.
 func (s *ss) consume(ok string, accept bool) bool {
        rune := s.getRune()
-       if rune == EOF {
+       if rune == eof {
                return false
        }
        if strings.IndexRune(ok, rune) >= 0 {
@@ -425,7 +442,7 @@ func (s *ss) consume(ok string, accept bool) bool {
                }
                return true
        }
-       if rune != EOF && accept {
+       if rune != eof && accept {
                s.UnreadRune()
        }
        return false
@@ -434,7 +451,7 @@ func (s *ss) consume(ok string, accept bool) bool {
 // peek reports whether the next character is in the ok string, without consuming it.
 func (s *ss) peek(ok string) bool {
        rune := s.getRune()
-       if rune != EOF {
+       if rune != eof {
                s.UnreadRune()
        }
        return strings.IndexRune(ok, rune) >= 0
@@ -729,7 +746,7 @@ func (s *ss) convertString(verb int) (str string) {
        case 'x':
                str = s.hexString()
        default:
-               str = s.token() // %s and %v just return the next word
+               str = string(s.token(true, notSpace)) // %s and %v just return the next word
        }
        // Empty strings other than with %q are not OK.
        if len(str) == 0 && verb != 'q' && s.maxWid > 0 {
@@ -797,7 +814,7 @@ func (s *ss) hexDigit(digit int) int {
 // There must be either two hexadecimal digits or a space character in the input.
 func (s *ss) hexByte() (b byte, ok bool) {
        rune1 := s.getRune()
-       if rune1 == EOF {
+       if rune1 == eof {
                return
        }
        if unicode.IsSpace(rune1) {
@@ -953,7 +970,7 @@ func (s *ss) doScan(a []interface{}) (numProcessed int, err os.Error) {
        if !s.nlIsSpace {
                for {
                        rune := s.getRune()
-                       if rune == '\n' || rune == EOF {
+                       if rune == '\n' || rune == eof {
                                break
                        }
                        if !unicode.IsSpace(rune) {
@@ -993,7 +1010,7 @@ func (s *ss) advance(format string) (i int) {
                        // There was space in the format, so there should be space (EOF)
                        // in the input.
                        inputc := s.getRune()
-                       if inputc == EOF {
+                       if inputc == eof {
                                return
                        }
                        if !unicode.IsSpace(inputc) {
index 65adb023686e8323a46e2612966b0633b59c5b4a..8d2e6f5c64e86448d8fe6e16d6112b794492771e 100644 (file)
@@ -88,14 +88,15 @@ type FloatTest struct {
 type Xs string
 
 func (x *Xs) Scan(state ScanState, verb int) os.Error {
-       tok, err := state.Token()
+       tok, err := state.Token(true, func(r int) bool { return r == verb })
        if err != nil {
                return err
        }
-       if !regexp.MustCompile("^" + string(verb) + "+$").MatchString(tok) {
+       s := string(tok)
+       if !regexp.MustCompile("^" + string(verb) + "+$").MatchString(s) {
                return os.ErrorString("syntax error for xs")
        }
-       *x = Xs(tok)
+       *x = Xs(s)
        return nil
 }
 
@@ -113,9 +114,11 @@ func (s *IntString) Scan(state ScanState, verb int) os.Error {
                return err
        }
 
-       if _, err := Fscan(state, &s.s); err != nil {
+       tok, err := state.Token(true, nil)
+       if err != nil {
                return err
        }
+       s.s = string(tok)
        return nil
 }
 
@@ -331,7 +334,7 @@ var multiTests = []ScanfMultiTest{
        {"%c%c%c", "2\u50c2X", args(&i, &j, &k), args('2', '\u50c2', 'X'), ""},
 
        // Custom scanners.
-       {"%2e%f", "eefffff", args(&x, &y), args(Xs("ee"), Xs("fffff")), ""},
+       {"%e%f", "eefffff", args(&x, &y), args(Xs("ee"), Xs("fffff")), ""},
        {"%4v%s", "12abcd", args(&z, &s), args(IntString{12, "ab"}, "cd"), ""},
 
        // Errors
@@ -476,22 +479,12 @@ func verifyInf(str string, t *testing.T) {
        }
 }
 
-
 func TestInf(t *testing.T) {
        for _, s := range []string{"inf", "+inf", "-inf", "INF", "-INF", "+INF", "Inf", "-Inf", "+Inf"} {
                verifyInf(s, t)
        }
 }
 
-// TODO: there's no conversion from []T to ...T, but we can fake it.  These
-// functions do the faking.  We index the table by the length of the param list.
-var fscanf = []func(io.Reader, string, []interface{}) (int, os.Error){
-       0: func(r io.Reader, f string, i []interface{}) (int, os.Error) { return Fscanf(r, f) },
-       1: func(r io.Reader, f string, i []interface{}) (int, os.Error) { return Fscanf(r, f, i[0]) },
-       2: func(r io.Reader, f string, i []interface{}) (int, os.Error) { return Fscanf(r, f, i[0], i[1]) },
-       3: func(r io.Reader, f string, i []interface{}) (int, os.Error) { return Fscanf(r, f, i[0], i[1], i[2]) },
-}
-
 func testScanfMulti(name string, t *testing.T) {
        sliceType := reflect.Typeof(make([]interface{}, 1)).(*reflect.SliceType)
        for _, test := range multiTests {
@@ -501,7 +494,7 @@ func testScanfMulti(name string, t *testing.T) {
                } else {
                        r = newReader(test.text)
                }
-               n, err := fscanf[len(test.in)](r, test.format, test.in)
+               n, err := Fscanf(r, test.format, test.in...)
                if err != nil {
                        if test.err == "" {
                                t.Errorf("got error scanning (%q, %q): %q", test.format, test.text, err)
@@ -830,12 +823,12 @@ func testScanInts(t *testing.T, scan func(*RecursiveInt, *bytes.Buffer) os.Error
        i := 1
        for ; r != nil; r = r.next {
                if r.i != i {
-                       t.Fatal("bad scan: expected %d got %d", i, r.i)
+                       t.Fatalf("bad scan: expected %d got %d", i, r.i)
                }
                i++
        }
        if i-1 != intCount {
-               t.Fatal("bad scan count: expected %d got %d", intCount, i-1)
+               t.Fatalf("bad scan count: expected %d got %d", intCount, i-1)
        }
 }
 
index abafb5663b3baf0ef0dcfdebfc9d303be68a6ae1..4a4c12b7c0a2065660584f509e5f0228183a5c51 100644 (file)
@@ -602,12 +602,12 @@ type (
                Else Stmt // else branch; or nil
        }
 
-       // A CaseClause represents a case of an expression switch statement.
+       // A CaseClause represents a case of an expression or type switch statement.
        CaseClause struct {
-               Case   token.Pos // position of "case" or "default" keyword
-               Values []Expr    // nil means default case
-               Colon  token.Pos // position of ":"
-               Body   []Stmt    // statement list; or nil
+               Case  token.Pos // position of "case" or "default" keyword
+               List  []Expr    // list of expressions or types; nil means default case
+               Colon token.Pos // position of ":"
+               Body  []Stmt    // statement list; or nil
        }
 
        // A SwitchStmt node represents an expression switch statement.
@@ -618,20 +618,12 @@ type (
                Body   *BlockStmt // CaseClauses only
        }
 
-       // A TypeCaseClause represents a case of a type switch statement.
-       TypeCaseClause struct {
-               Case  token.Pos // position of "case" or "default" keyword
-               Types []Expr    // nil means default case
-               Colon token.Pos // position of ":"
-               Body  []Stmt    // statement list; or nil
-       }
-
        // An TypeSwitchStmt node represents a type switch statement.
        TypeSwitchStmt struct {
                Switch token.Pos  // position of "switch" keyword
                Init   Stmt       // initalization statement; or nil
-               Assign Stmt       // x := y.(type)
-               Body   *BlockStmt // TypeCaseClauses only
+               Assign Stmt       // x := y.(type) or y.(type)
+               Body   *BlockStmt // CaseClauses only
        }
 
        // A CommClause node represents a case of a select statement.
@@ -687,7 +679,6 @@ func (s *BlockStmt) Pos() token.Pos      { return s.Lbrace }
 func (s *IfStmt) Pos() token.Pos         { return s.If }
 func (s *CaseClause) Pos() token.Pos     { return s.Case }
 func (s *SwitchStmt) Pos() token.Pos     { return s.Switch }
-func (s *TypeCaseClause) Pos() token.Pos { return s.Case }
 func (s *TypeSwitchStmt) Pos() token.Pos { return s.Switch }
 func (s *CommClause) Pos() token.Pos     { return s.Case }
 func (s *SelectStmt) Pos() token.Pos     { return s.Select }
@@ -734,13 +725,7 @@ func (s *CaseClause) End() token.Pos {
        }
        return s.Colon + 1
 }
-func (s *SwitchStmt) End() token.Pos { return s.Body.End() }
-func (s *TypeCaseClause) End() token.Pos {
-       if n := len(s.Body); n > 0 {
-               return s.Body[n-1].End()
-       }
-       return s.Colon + 1
-}
+func (s *SwitchStmt) End() token.Pos     { return s.Body.End() }
 func (s *TypeSwitchStmt) End() token.Pos { return s.Body.End() }
 func (s *CommClause) End() token.Pos {
        if n := len(s.Body); n > 0 {
@@ -772,7 +757,6 @@ func (s *BlockStmt) stmtNode()      {}
 func (s *IfStmt) stmtNode()         {}
 func (s *CaseClause) stmtNode()     {}
 func (s *SwitchStmt) stmtNode()     {}
-func (s *TypeCaseClause) stmtNode() {}
 func (s *TypeSwitchStmt) stmtNode() {}
 func (s *CommClause) stmtNode()     {}
 func (s *SelectStmt) stmtNode()     {}
@@ -937,11 +921,13 @@ func (d *FuncDecl) declNode() {}
 // via Doc and Comment fields.
 //
 type File struct {
-       Doc      *CommentGroup   // associated documentation; or nil
-       Package  token.Pos       // position of "package" keyword
-       Name     *Ident          // package name
-       Decls    []Decl          // top-level declarations; or nil
-       Comments []*CommentGroup // list of all comments in the source file
+       Doc        *CommentGroup   // associated documentation; or nil
+       Package    token.Pos       // position of "package" keyword
+       Name       *Ident          // package name
+       Decls      []Decl          // top-level declarations; or nil
+       Scope      *Scope          // package scope
+       Unresolved []*Ident        // unresolved global identifiers
+       Comments   []*CommentGroup // list of all comments in the source file
 }
 
 
@@ -959,7 +945,7 @@ func (f *File) End() token.Pos {
 //
 type Package struct {
        Name  string           // package name
-       Scope *Scope           // package scope; or nil
+       Scope *Scope           // package scope
        Files map[string]*File // Go source files by filename
 }
 
index 0c3cef4b27b15c976c8642c409d7a694f7d53c68..4da487ce02ad02fd7c508b64809537f6543edbc3 100644 (file)
@@ -425,5 +425,6 @@ func MergePackageFiles(pkg *Package, mode MergeMode) *File {
                }
        }
 
-       return &File{doc, pos, NewIdent(pkg.Name), decls, comments}
+       // TODO(gri) need to compute pkgScope and unresolved identifiers!
+       return &File{doc, pos, NewIdent(pkg.Name), decls, nil, nil, comments}
 }
index d71490d4a9ff440bb2c032987894ab0d0c4d50d9..82c334ece67504ac4d21523db72c9439a4c71173 100644 (file)
@@ -30,15 +30,19 @@ func NotNilFilter(_ string, value reflect.Value) bool {
 
 
 // Fprint prints the (sub-)tree starting at AST node x to w.
+// If fset != nil, position information is interpreted relative
+// to that file set. Otherwise positions are printed as integer
+// values (file set specific offsets).
 //
 // A non-nil FieldFilter f may be provided to control the output:
 // struct fields for which f(fieldname, fieldvalue) is true are
 // are printed; all others are filtered from the output.
 //
-func Fprint(w io.Writer, x interface{}, f FieldFilter) (n int, err os.Error) {
+func Fprint(w io.Writer, fset *token.FileSet, x interface{}, f FieldFilter) (n int, err os.Error) {
        // setup printer
        p := printer{
                output: w,
+               fset:   fset,
                filter: f,
                ptrmap: make(map[interface{}]int),
                last:   '\n', // force printing of line number on first line
@@ -65,14 +69,15 @@ func Fprint(w io.Writer, x interface{}, f FieldFilter) (n int, err os.Error) {
 
 
 // Print prints x to standard output, skipping nil fields.
-// Print(x) is the same as Fprint(os.Stdout, x, NotNilFilter).
-func Print(x interface{}) (int, os.Error) {
-       return Fprint(os.Stdout, x, NotNilFilter)
+// Print(fset, x) is the same as Fprint(os.Stdout, fset, x, NotNilFilter).
+func Print(fset *token.FileSet, x interface{}) (int, os.Error) {
+       return Fprint(os.Stdout, fset, x, NotNilFilter)
 }
 
 
 type printer struct {
        output  io.Writer
+       fset    *token.FileSet
        filter  FieldFilter
        ptrmap  map[interface{}]int // *reflect.PtrValue -> line number
        written int                 // number of bytes written to output
@@ -137,16 +142,6 @@ func (p *printer) printf(format string, args ...interface{}) {
 // probably be in a different package.
 
 func (p *printer) print(x reflect.Value) {
-       // Note: This test is only needed because AST nodes
-       //       embed a token.Position, and thus all of them
-       //       understand the String() method (but it only
-       //       applies to the Position field).
-       // TODO: Should reconsider this AST design decision.
-       if pos, ok := x.Interface().(token.Position); ok {
-               p.printf("%s", pos)
-               return
-       }
-
        if !NotNilFilter("", x) {
                p.printf("nil")
                return
@@ -163,6 +158,7 @@ func (p *printer) print(x reflect.Value) {
                        p.print(key)
                        p.printf(": ")
                        p.print(v.Elem(key))
+                       p.printf("\n")
                }
                p.indent--
                p.printf("}")
@@ -212,6 +208,11 @@ func (p *printer) print(x reflect.Value) {
                p.printf("}")
 
        default:
-               p.printf("%v", x.Interface())
+               value := x.Interface()
+               // position values can be printed nicely if we have a file set
+               if pos, ok := value.(token.Pos); ok && p.fset != nil {
+                       value = p.fset.Position(pos)
+               }
+               p.printf("%v", value)
        }
 }
index 956a208aede5aa3e24d482dc879bbf7373e230e0..91866dcf57b4a4194e1c42e112a40b8e0ad1b9c4 100644 (file)
@@ -2,31 +2,31 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// This file implements scopes, the objects they contain,
-// and object types.
+// This file implements scopes and the objects they contain.
 
 package ast
 
+import (
+       "bytes"
+       "fmt"
+       "go/token"
+)
+
+
 // A Scope maintains the set of named language entities declared
 // in the scope and a link to the immediately surrounding (outer)
 // scope.
 //
 type Scope struct {
        Outer   *Scope
-       Objects []*Object // in declaration order
-       // Implementation note: In some cases (struct fields,
-       // function parameters) we need the source order of
-       // variables. Thus for now, we store scope entries
-       // in a linear list. If scopes become very large
-       // (say, for packages), we may need to change this
-       // to avoid slow lookups.
+       Objects map[string]*Object
 }
 
 
 // NewScope creates a new scope nested in the outer scope.
 func NewScope(outer *Scope) *Scope {
-       const n = 4 // initial scope capacity, must be > 0
-       return &Scope{outer, make([]*Object, 0, n)}
+       const n = 4 // initial scope capacity
+       return &Scope{outer, make(map[string]*Object, n)}
 }
 
 
@@ -34,73 +34,108 @@ func NewScope(outer *Scope) *Scope {
 // found in scope s, otherwise it returns nil. Outer scopes
 // are ignored.
 //
-// Lookup always returns nil if name is "_", even if the scope
-// contains objects with that name.
-//
 func (s *Scope) Lookup(name string) *Object {
-       if name != "_" {
-               for _, obj := range s.Objects {
-                       if obj.Name == name {
-                               return obj
-                       }
-               }
-       }
-       return nil
+       return s.Objects[name]
 }
 
 
 // Insert attempts to insert a named object into the scope s.
-// If the scope does not contain an object with that name yet
-// or if the object is named "_", Insert inserts the object
-// and returns it. Otherwise, Insert leaves the scope unchanged
-// and returns the object found in the scope instead.
+// If the scope does not contain an object with that name yet,
+// Insert inserts the object and returns it. Otherwise, Insert
+// leaves the scope unchanged and returns the object found in
+// the scope instead.
 //
-func (s *Scope) Insert(obj *Object) *Object {
-       alt := s.Lookup(obj.Name)
-       if alt == nil {
-               s.append(obj)
+func (s *Scope) Insert(obj *Object) (alt *Object) {
+       if alt = s.Objects[obj.Name]; alt == nil {
+               s.Objects[obj.Name] = obj
                alt = obj
        }
-       return alt
+       return
 }
 
 
-func (s *Scope) append(obj *Object) {
-       s.Objects = append(s.Objects, obj)
+// Debugging support
+func (s *Scope) String() string {
+       var buf bytes.Buffer
+       fmt.Fprintf(&buf, "scope %p {", s)
+       if s != nil && len(s.Objects) > 0 {
+               fmt.Fprintln(&buf)
+               for _, obj := range s.Objects {
+                       fmt.Fprintf(&buf, "\t%s %s\n", obj.Kind, obj.Name)
+               }
+       }
+       fmt.Fprintf(&buf, "}\n")
+       return buf.String()
 }
 
+
 // ----------------------------------------------------------------------------
 // Objects
 
-// An Object describes a language entity such as a package,
-// constant, type, variable, or function (incl. methods).
+// An Object describes a named language entity such as a package,
+// constant, type, variable, function (incl. methods), or label.
 //
 type Object struct {
-       Kind Kind
-       Name string // declared name
-       Type *Type
-       Decl interface{} // corresponding Field, XxxSpec or FuncDecl
-       N    int         // value of iota for this declaration
+       Kind ObjKind
+       Name string      // declared name
+       Decl interface{} // corresponding Field, XxxSpec, FuncDecl, or LabeledStmt; or nil
+       Type interface{} // place holder for type information; may be nil
 }
 
 
 // NewObj creates a new object of a given kind and name.
-func NewObj(kind Kind, name string) *Object {
+func NewObj(kind ObjKind, name string) *Object {
        return &Object{Kind: kind, Name: name}
 }
 
 
-// Kind describes what an object represents.
-type Kind int
+// Pos computes the source position of the declaration of an object name.
+// The result may be an invalid position if it cannot be computed
+// (obj.Decl may be nil or not correct).
+func (obj *Object) Pos() token.Pos {
+       name := obj.Name
+       switch d := obj.Decl.(type) {
+       case *Field:
+               for _, n := range d.Names {
+                       if n.Name == name {
+                               return n.Pos()
+                       }
+               }
+       case *ValueSpec:
+               for _, n := range d.Names {
+                       if n.Name == name {
+                               return n.Pos()
+                       }
+               }
+       case *TypeSpec:
+               if d.Name.Name == name {
+                       return d.Name.Pos()
+               }
+       case *FuncDecl:
+               if d.Name.Name == name {
+                       return d.Name.Pos()
+               }
+       case *LabeledStmt:
+               if d.Label.Name == name {
+                       return d.Label.Pos()
+               }
+       }
+       return token.NoPos
+}
+
+
+// ObKind describes what an object represents.
+type ObjKind int
 
 // The list of possible Object kinds.
 const (
-       Bad Kind = iota // for error handling
-       Pkg             // package
-       Con             // constant
-       Typ             // type
-       Var             // variable
-       Fun             // function or method
+       Bad ObjKind = iota // for error handling
+       Pkg                // package
+       Con                // constant
+       Typ                // type
+       Var                // variable
+       Fun                // function or method
+       Lbl                // label
 )
 
 
@@ -111,132 +146,8 @@ var objKindStrings = [...]string{
        Typ: "type",
        Var: "var",
        Fun: "func",
+       Lbl: "label",
 }
 
 
-func (kind Kind) String() string { return objKindStrings[kind] }
-
-
-// IsExported returns whether obj is exported.
-func (obj *Object) IsExported() bool { return IsExported(obj.Name) }
-
-
-// ----------------------------------------------------------------------------
-// Types
-
-// A Type represents a Go type.
-type Type struct {
-       Form     Form
-       Obj      *Object // corresponding type name, or nil
-       Scope    *Scope  // fields and methods, always present
-       N        uint    // basic type id, array length, number of function results, or channel direction
-       Key, Elt *Type   // map key and array, pointer, slice, map or channel element
-       Params   *Scope  // function (receiver, input and result) parameters, tuple expressions (results of function calls), or nil
-       Expr     Expr    // corresponding AST expression
-}
-
-
-// NewType creates a new type of a given form.
-func NewType(form Form) *Type {
-       return &Type{Form: form, Scope: NewScope(nil)}
-}
-
-
-// Form describes the form of a type.
-type Form int
-
-// The list of possible type forms.
-const (
-       BadType    Form = iota // for error handling
-       Unresolved             // type not fully setup
-       Basic
-       Array
-       Struct
-       Pointer
-       Function
-       Method
-       Interface
-       Slice
-       Map
-       Channel
-       Tuple
-)
-
-
-var formStrings = [...]string{
-       BadType:    "badType",
-       Unresolved: "unresolved",
-       Basic:      "basic",
-       Array:      "array",
-       Struct:     "struct",
-       Pointer:    "pointer",
-       Function:   "function",
-       Method:     "method",
-       Interface:  "interface",
-       Slice:      "slice",
-       Map:        "map",
-       Channel:    "channel",
-       Tuple:      "tuple",
-}
-
-
-func (form Form) String() string { return formStrings[form] }
-
-
-// The list of basic type id's.
-const (
-       Bool = iota
-       Byte
-       Uint
-       Int
-       Float
-       Complex
-       Uintptr
-       String
-
-       Uint8
-       Uint16
-       Uint32
-       Uint64
-
-       Int8
-       Int16
-       Int32
-       Int64
-
-       Float32
-       Float64
-
-       Complex64
-       Complex128
-
-       // TODO(gri) ideal types are missing
-)
-
-
-var BasicTypes = map[uint]string{
-       Bool:    "bool",
-       Byte:    "byte",
-       Uint:    "uint",
-       Int:     "int",
-       Float:   "float",
-       Complex: "complex",
-       Uintptr: "uintptr",
-       String:  "string",
-
-       Uint8:  "uint8",
-       Uint16: "uint16",
-       Uint32: "uint32",
-       Uint64: "uint64",
-
-       Int8:  "int8",
-       Int16: "int16",
-       Int32: "int32",
-       Int64: "int64",
-
-       Float32: "float32",
-       Float64: "float64",
-
-       Complex64:  "complex64",
-       Complex128: "complex128",
-}
+func (kind ObjKind) String() string { return objKindStrings[kind] }
index 20c337c3be98b8f469da3f9d42f29979f03a78a7..95c4b3a3564948379abd558473bd66dffc80491f 100644 (file)
@@ -234,7 +234,7 @@ func Walk(v Visitor, node Node) {
                }
 
        case *CaseClause:
-               walkExprList(v, n.Values)
+               walkExprList(v, n.List)
                walkStmtList(v, n.Body)
 
        case *SwitchStmt:
@@ -246,12 +246,6 @@ func Walk(v Visitor, node Node) {
                }
                Walk(v, n.Body)
 
-       case *TypeCaseClause:
-               for _, x := range n.Types {
-                       Walk(v, x)
-               }
-               walkStmtList(v, n.Body)
-
        case *TypeSwitchStmt:
                if n.Init != nil {
                        Walk(v, n.Init)
index 84d699a67935910de4159c7f0b0ffc2be92ecade..6f35b495efa15234c2a1b06e505031e44ae5edca 100644 (file)
@@ -14,7 +14,7 @@ import (
        "io"
        "io/ioutil"
        "os"
-       pathutil "path"
+       "path/filepath"
 )
 
 
@@ -198,7 +198,7 @@ func ParseDir(fset *token.FileSet, path string, filter func(*os.FileInfo) bool,
        for i := 0; i < len(list); i++ {
                d := &list[i]
                if filter == nil || filter(d) {
-                       filenames[n] = pathutil.Join(path, d.Name)
+                       filenames[n] = filepath.Join(path, d.Name)
                        n++
                }
        }
index 7c5843f36371a19b23b476edf985bc13f666a633..b0e8c8ad7a8e86f6674e083962d40cd4b835b4f2 100644 (file)
@@ -17,10 +17,6 @@ import (
 )
 
 
-// noPos is used when there is no corresponding source position for a token.
-var noPos token.Position
-
-
 // The mode parameter to the Parse* functions is a set of flags (or 0).
 // They control the amount of source code parsed and other optional
 // parser functionality.
@@ -30,6 +26,7 @@ const (
        ImportsOnly                        // parsing stops after import declarations
        ParseComments                      // parse comments and add them to AST
        Trace                              // print a trace of parsed productions
+       DeclarationErrors                  // report declaration errors
 )
 
 
@@ -46,16 +43,26 @@ type parser struct {
 
        // Comments
        comments    []*ast.CommentGroup
-       leadComment *ast.CommentGroup // the last lead comment
-       lineComment *ast.CommentGroup // the last line comment
+       leadComment *ast.CommentGroup // last lead comment
+       lineComment *ast.CommentGroup // last line comment
 
        // Next token
-       pos token.Pos   // token position
-       tok token.Token // one token look-ahead
-       lit []byte      // token literal
+       pos  token.Pos   // token position
+       tok  token.Token // one token look-ahead
+       lit_ []byte      // token literal (slice into original source, don't hold on to it)
 
        // Non-syntactic parser control
        exprLev int // < 0: in control clause, >= 0: in expression
+
+       // Ordinary identifer scopes
+       pkgScope   *ast.Scope   // pkgScope.Outer == nil
+       topScope   *ast.Scope   // top-most scope; may be pkgScope
+       unresolved []*ast.Ident // unresolved global identifiers
+
+       // Label scope
+       // (maintained by open/close LabelScope)
+       labelScope  *ast.Scope     // label scope for current function
+       targetStack [][]*ast.Ident // stack of unresolved labels
 }
 
 
@@ -72,9 +79,126 @@ func scannerMode(mode uint) uint {
 func (p *parser) init(fset *token.FileSet, filename string, src []byte, mode uint) {
        p.file = fset.AddFile(filename, fset.Base(), len(src))
        p.scanner.Init(p.file, src, p, scannerMode(mode))
+
        p.mode = mode
        p.trace = mode&Trace != 0 // for convenience (p.trace is used frequently)
+
        p.next()
+
+       // set up the pkgScope here (as opposed to in parseFile) because
+       // there are other parser entry points (ParseExpr, etc.)
+       p.openScope()
+       p.pkgScope = p.topScope
+
+       // for the same reason, set up a label scope
+       p.openLabelScope()
+}
+
+
+func (p *parser) lit() []byte {
+       // make a copy of p.lit_ so that we don't hold on to
+       // a copy of the entire source indirectly in the AST
+       t := make([]byte, len(p.lit_))
+       copy(t, p.lit_)
+       return t
+}
+
+
+// ----------------------------------------------------------------------------
+// Scoping support
+
+func (p *parser) openScope() {
+       p.topScope = ast.NewScope(p.topScope)
+}
+
+
+func (p *parser) closeScope() {
+       p.topScope = p.topScope.Outer
+}
+
+
+func (p *parser) openLabelScope() {
+       p.labelScope = ast.NewScope(p.labelScope)
+       p.targetStack = append(p.targetStack, nil)
+}
+
+
+func (p *parser) closeLabelScope() {
+       // resolve labels
+       n := len(p.targetStack) - 1
+       scope := p.labelScope
+       for _, ident := range p.targetStack[n] {
+               ident.Obj = scope.Lookup(ident.Name)
+               if ident.Obj == nil && p.mode&DeclarationErrors != 0 {
+                       p.error(ident.Pos(), fmt.Sprintf("label %s undefined", ident.Name))
+               }
+       }
+       // pop label scope
+       p.targetStack = p.targetStack[0:n]
+       p.labelScope = p.labelScope.Outer
+}
+
+
+func (p *parser) declare(decl interface{}, scope *ast.Scope, kind ast.ObjKind, idents ...*ast.Ident) {
+       for _, ident := range idents {
+               if ident.Name != "_" {
+                       obj := ast.NewObj(kind, ident.Name)
+                       // remember the corresponding declaration for redeclaration
+                       // errors and global variable resolution/typechecking phase
+                       obj.Decl = decl
+                       alt := scope.Insert(obj)
+                       if alt != obj && p.mode&DeclarationErrors != 0 {
+                               prevDecl := ""
+                               if pos := alt.Pos(); pos.IsValid() {
+                                       prevDecl = fmt.Sprintf("\n\tprevious declaration at %s", p.file.Position(pos))
+                               }
+                               p.error(ident.Pos(), fmt.Sprintf("%s redeclared in this block%s", ident.Name, prevDecl))
+                       }
+                       ident.Obj = obj
+               }
+       }
+}
+
+
+func (p *parser) shortVarDecl(idents []*ast.Ident) {
+       // Go spec: A short variable declaration may redeclare variables
+       // provided they were originally declared in the same block with
+       // the same type, and at least one of the non-blank variables is new.
+       n := 0 // number of new variables
+       for _, ident := range idents {
+               if ident.Name != "_" {
+                       obj := ast.NewObj(ast.Var, ident.Name)
+                       // short var declarations cannot have redeclaration errors
+                       // and are not global => no need to remember the respective
+                       // declaration
+                       alt := p.topScope.Insert(obj)
+                       if alt == obj {
+                               n++ // new declaration
+                       }
+                       ident.Obj = alt
+               }
+       }
+       if n == 0 && p.mode&DeclarationErrors != 0 {
+               p.error(idents[0].Pos(), "no new variables on left side of :=")
+       }
+}
+
+
+func (p *parser) resolve(ident *ast.Ident) {
+       if ident.Name == "_" {
+               return
+       }
+       // try to resolve the identifier
+       for s := p.topScope; s != nil; s = s.Outer {
+               if obj := s.Lookup(ident.Name); obj != nil {
+                       ident.Obj = obj
+                       return
+               }
+       }
+       // collect unresolved global identifiers; ignore the others
+       if p.topScope == p.pkgScope {
+               p.unresolved = append(p.unresolved, ident)
+       }
 }
 
 
@@ -120,7 +244,7 @@ func (p *parser) next0() {
                s := p.tok.String()
                switch {
                case p.tok.IsLiteral():
-                       p.printTrace(s, string(p.lit))
+                       p.printTrace(s, string(p.lit_))
                case p.tok.IsOperator(), p.tok.IsKeyword():
                        p.printTrace("\"" + s + "\"")
                default:
@@ -128,7 +252,7 @@ func (p *parser) next0() {
                }
        }
 
-       p.pos, p.tok, p.lit = p.scanner.Scan()
+       p.pos, p.tok, p.lit_ = p.scanner.Scan()
 }
 
 // Consume a comment and return it and the line on which it ends.
@@ -136,15 +260,15 @@ func (p *parser) consumeComment() (comment *ast.Comment, endline int) {
        // /*-style comments may end on a different line than where they start.
        // Scan the comment for '\n' chars and adjust endline accordingly.
        endline = p.file.Line(p.pos)
-       if p.lit[1] == '*' {
-               for _, b := range p.lit {
+       if p.lit_[1] == '*' {
+               for _, b := range p.lit_ {
                        if b == '\n' {
                                endline++
                        }
                }
        }
 
-       comment = &ast.Comment{p.pos, p.lit}
+       comment = &ast.Comment{p.pos, p.lit()}
        p.next0()
 
        return
@@ -234,12 +358,12 @@ func (p *parser) errorExpected(pos token.Pos, msg string) {
        if pos == p.pos {
                // the error happened at the current position;
                // make the error message more specific
-               if p.tok == token.SEMICOLON && p.lit[0] == '\n' {
+               if p.tok == token.SEMICOLON && p.lit_[0] == '\n' {
                        msg += ", found newline"
                } else {
                        msg += ", found '" + p.tok.String() + "'"
                        if p.tok.IsLiteral() {
-                               msg += " " + string(p.lit)
+                               msg += " " + string(p.lit_)
                        }
                }
        }
@@ -271,7 +395,7 @@ func (p *parser) parseIdent() *ast.Ident {
        pos := p.pos
        name := "_"
        if p.tok == token.IDENT {
-               name = string(p.lit)
+               name = string(p.lit_)
                p.next()
        } else {
                p.expect(token.IDENT) // use expect() error handling
@@ -339,13 +463,16 @@ func (p *parser) parseQualifiedIdent() ast.Expr {
                defer un(trace(p, "QualifiedIdent"))
        }
 
-       var x ast.Expr = p.parseIdent()
+       ident := p.parseIdent()
+       p.resolve(ident)
+       var x ast.Expr = ident
        if p.tok == token.PERIOD {
                // first identifier is a package identifier
                p.next()
                sel := p.parseIdent()
                x = &ast.SelectorExpr{x, sel}
        }
+
        return x
 }
 
@@ -407,7 +534,7 @@ func (p *parser) parseFieldDecl() *ast.Field {
        // optional tag
        var tag *ast.BasicLit
        if p.tok == token.STRING {
-               tag = &ast.BasicLit{p.pos, p.tok, p.lit}
+               tag = &ast.BasicLit{p.pos, p.tok, p.lit()}
                p.next()
        }
 
@@ -426,7 +553,7 @@ func (p *parser) parseFieldDecl() *ast.Field {
                }
        }
 
-       p.expectSemi()
+       p.expectSemi() // call before accessing p.linecomment
 
        return &ast.Field{doc, idents, typ, tag, p.lineComment}
 }
@@ -519,7 +646,7 @@ func (p *parser) parseVarList(isParam bool) (list []ast.Expr, typ ast.Expr) {
 }
 
 
-func (p *parser) parseParameterList(ellipsisOk bool) (params []*ast.Field) {
+func (p *parser) parseParameterList(scope *ast.Scope, ellipsisOk bool) (params []*ast.Field) {
        if p.trace {
                defer un(trace(p, "ParameterList"))
        }
@@ -528,7 +655,11 @@ func (p *parser) parseParameterList(ellipsisOk bool) (params []*ast.Field) {
        if typ != nil {
                // IdentifierList Type
                idents := p.makeIdentList(list)
-               params = append(params, &ast.Field{nil, idents, typ, nil, nil})
+               field := &ast.Field{nil, idents, typ, nil, nil}
+               params = append(params, field)
+               // Go spec: The scope of an identifier denoting a function
+               // parameter or result variable is the function body.
+               p.declare(field, scope, ast.Var, idents...)
                if p.tok == token.COMMA {
                        p.next()
                }
@@ -536,7 +667,11 @@ func (p *parser) parseParameterList(ellipsisOk bool) (params []*ast.Field) {
                for p.tok != token.RPAREN && p.tok != token.EOF {
                        idents := p.parseIdentList()
                        typ := p.parseVarType(ellipsisOk)
-                       params = append(params, &ast.Field{nil, idents, typ, nil, nil})
+                       field := &ast.Field{nil, idents, typ, nil, nil}
+                       params = append(params, field)
+                       // Go spec: The scope of an identifier denoting a function
+                       // parameter or result variable is the function body.
+                       p.declare(field, scope, ast.Var, idents...)
                        if p.tok != token.COMMA {
                                break
                        }
@@ -555,7 +690,7 @@ func (p *parser) parseParameterList(ellipsisOk bool) (params []*ast.Field) {
 }
 
 
-func (p *parser) parseParameters(ellipsisOk bool) *ast.FieldList {
+func (p *parser) parseParameters(scope *ast.Scope, ellipsisOk bool) *ast.FieldList {
        if p.trace {
                defer un(trace(p, "Parameters"))
        }
@@ -563,7 +698,7 @@ func (p *parser) parseParameters(ellipsisOk bool) *ast.FieldList {
        var params []*ast.Field
        lparen := p.expect(token.LPAREN)
        if p.tok != token.RPAREN {
-               params = p.parseParameterList(ellipsisOk)
+               params = p.parseParameterList(scope, ellipsisOk)
        }
        rparen := p.expect(token.RPAREN)
 
@@ -571,13 +706,13 @@ func (p *parser) parseParameters(ellipsisOk bool) *ast.FieldList {
 }
 
 
-func (p *parser) parseResult() *ast.FieldList {
+func (p *parser) parseResult(scope *ast.Scope) *ast.FieldList {
        if p.trace {
                defer un(trace(p, "Result"))
        }
 
        if p.tok == token.LPAREN {
-               return p.parseParameters(false)
+               return p.parseParameters(scope, false)
        }
 
        typ := p.tryType()
@@ -591,27 +726,28 @@ func (p *parser) parseResult() *ast.FieldList {
 }
 
 
-func (p *parser) parseSignature() (params, results *ast.FieldList) {
+func (p *parser) parseSignature(scope *ast.Scope) (params, results *ast.FieldList) {
        if p.trace {
                defer un(trace(p, "Signature"))
        }
 
-       params = p.parseParameters(true)
-       results = p.parseResult()
+       params = p.parseParameters(scope, true)
+       results = p.parseResult(scope)
 
        return
 }
 
 
-func (p *parser) parseFuncType() *ast.FuncType {
+func (p *parser) parseFuncType() (*ast.FuncType, *ast.Scope) {
        if p.trace {
                defer un(trace(p, "FuncType"))
        }
 
        pos := p.expect(token.FUNC)
-       params, results := p.parseSignature()
+       scope := ast.NewScope(p.topScope) // function scope
+       params, results := p.parseSignature(scope)
 
-       return &ast.FuncType{pos, params, results}
+       return &ast.FuncType{pos, params, results}, scope
 }
 
 
@@ -627,13 +763,14 @@ func (p *parser) parseMethodSpec() *ast.Field {
        if ident, isIdent := x.(*ast.Ident); isIdent && p.tok == token.LPAREN {
                // method
                idents = []*ast.Ident{ident}
-               params, results := p.parseSignature()
+               scope := ast.NewScope(nil) // method scope
+               params, results := p.parseSignature(scope)
                typ = &ast.FuncType{token.NoPos, params, results}
        } else {
                // embedded interface
                typ = x
        }
-       p.expectSemi()
+       p.expectSemi() // call before accessing p.linecomment
 
        return &ast.Field{doc, idents, typ, nil, p.lineComment}
 }
@@ -706,7 +843,8 @@ func (p *parser) tryRawType(ellipsisOk bool) ast.Expr {
        case token.MUL:
                return p.parsePointerType()
        case token.FUNC:
-               return p.parseFuncType()
+               typ, _ := p.parseFuncType()
+               return typ
        case token.INTERFACE:
                return p.parseInterfaceType()
        case token.MAP:
@@ -745,13 +883,17 @@ func (p *parser) parseStmtList() (list []ast.Stmt) {
 }
 
 
-func (p *parser) parseBody() *ast.BlockStmt {
+func (p *parser) parseBody(scope *ast.Scope) *ast.BlockStmt {
        if p.trace {
                defer un(trace(p, "Body"))
        }
 
        lbrace := p.expect(token.LBRACE)
+       p.topScope = scope // open function scope
+       p.openLabelScope()
        list := p.parseStmtList()
+       p.closeLabelScope()
+       p.closeScope()
        rbrace := p.expect(token.RBRACE)
 
        return &ast.BlockStmt{lbrace, list, rbrace}
@@ -764,7 +906,9 @@ func (p *parser) parseBlockStmt() *ast.BlockStmt {
        }
 
        lbrace := p.expect(token.LBRACE)
+       p.openScope()
        list := p.parseStmtList()
+       p.closeScope()
        rbrace := p.expect(token.RBRACE)
 
        return &ast.BlockStmt{lbrace, list, rbrace}
@@ -779,14 +923,14 @@ func (p *parser) parseFuncTypeOrLit() ast.Expr {
                defer un(trace(p, "FuncTypeOrLit"))
        }
 
-       typ := p.parseFuncType()
+       typ, scope := p.parseFuncType()
        if p.tok != token.LBRACE {
                // function type only
                return typ
        }
 
        p.exprLev++
-       body := p.parseBody()
+       body := p.parseBody(scope)
        p.exprLev--
 
        return &ast.FuncLit{typ, body}
@@ -803,10 +947,12 @@ func (p *parser) parseOperand() ast.Expr {
 
        switch p.tok {
        case token.IDENT:
-               return p.parseIdent()
+               ident := p.parseIdent()
+               p.resolve(ident)
+               return ident
 
        case token.INT, token.FLOAT, token.IMAG, token.CHAR, token.STRING:
-               x := &ast.BasicLit{p.pos, p.tok, p.lit}
+               x := &ast.BasicLit{p.pos, p.tok, p.lit()}
                p.next()
                return x
 
@@ -1202,6 +1348,9 @@ func (p *parser) parseSimpleStmt(labelOk bool) ast.Stmt {
                pos, tok := p.pos, p.tok
                p.next()
                y := p.parseExprList()
+               if tok == token.DEFINE {
+                       p.shortVarDecl(p.makeIdentList(x))
+               }
                return &ast.AssignStmt{x, pos, tok, y}
        }
 
@@ -1216,7 +1365,12 @@ func (p *parser) parseSimpleStmt(labelOk bool) ast.Stmt {
                colon := p.pos
                p.next()
                if label, isIdent := x[0].(*ast.Ident); labelOk && isIdent {
-                       return &ast.LabeledStmt{label, colon, p.parseStmt()}
+                       // Go spec: The scope of a label is the body of the function
+                       // in which it is declared and excludes the body of any nested
+                       // function.
+                       stmt := &ast.LabeledStmt{label, colon, p.parseStmt()}
+                       p.declare(stmt, p.labelScope, ast.Lbl, label)
+                       return stmt
                }
                p.error(x[0].Pos(), "illegal label declaration")
                return &ast.BadStmt{x[0].Pos(), colon + 1}
@@ -1304,14 +1458,17 @@ func (p *parser) parseBranchStmt(tok token.Token) *ast.BranchStmt {
                defer un(trace(p, "BranchStmt"))
        }
 
-       s := &ast.BranchStmt{p.pos, tok, nil}
-       p.expect(tok)
+       pos := p.expect(tok)
+       var label *ast.Ident
        if tok != token.FALLTHROUGH && p.tok == token.IDENT {
-               s.Label = p.parseIdent()
+               label = p.parseIdent()
+               // add to list of unresolved targets
+               n := len(p.targetStack) - 1
+               p.targetStack[n] = append(p.targetStack[n], label)
        }
        p.expectSemi()
 
-       return s
+       return &ast.BranchStmt{pos, tok, label}
 }
 
 
@@ -1333,6 +1490,8 @@ func (p *parser) parseIfStmt() *ast.IfStmt {
        }
 
        pos := p.expect(token.IF)
+       p.openScope()
+       defer p.closeScope()
 
        var s ast.Stmt
        var x ast.Expr
@@ -1368,28 +1527,6 @@ func (p *parser) parseIfStmt() *ast.IfStmt {
 }
 
 
-func (p *parser) parseCaseClause() *ast.CaseClause {
-       if p.trace {
-               defer un(trace(p, "CaseClause"))
-       }
-
-       // SwitchCase
-       pos := p.pos
-       var x []ast.Expr
-       if p.tok == token.CASE {
-               p.next()
-               x = p.parseExprList()
-       } else {
-               p.expect(token.DEFAULT)
-       }
-
-       colon := p.expect(token.COLON)
-       body := p.parseStmtList()
-
-       return &ast.CaseClause{pos, x, colon, body}
-}
-
-
 func (p *parser) parseTypeList() (list []ast.Expr) {
        if p.trace {
                defer un(trace(p, "TypeList"))
@@ -1405,25 +1542,30 @@ func (p *parser) parseTypeList() (list []ast.Expr) {
 }
 
 
-func (p *parser) parseTypeCaseClause() *ast.TypeCaseClause {
+func (p *parser) parseCaseClause(exprSwitch bool) *ast.CaseClause {
        if p.trace {
-               defer un(trace(p, "TypeCaseClause"))
+               defer un(trace(p, "CaseClause"))
        }
 
-       // TypeSwitchCase
        pos := p.pos
-       var types []ast.Expr
+       var list []ast.Expr
        if p.tok == token.CASE {
                p.next()
-               types = p.parseTypeList()
+               if exprSwitch {
+                       list = p.parseExprList()
+               } else {
+                       list = p.parseTypeList()
+               }
        } else {
                p.expect(token.DEFAULT)
        }
 
        colon := p.expect(token.COLON)
+       p.openScope()
        body := p.parseStmtList()
+       p.closeScope()
 
-       return &ast.TypeCaseClause{pos, types, colon, body}
+       return &ast.CaseClause{pos, list, colon, body}
 }
 
 
@@ -1447,6 +1589,8 @@ func (p *parser) parseSwitchStmt() ast.Stmt {
        }
 
        pos := p.expect(token.SWITCH)
+       p.openScope()
+       defer p.closeScope()
 
        var s1, s2 ast.Stmt
        if p.tok != token.LBRACE {
@@ -1466,28 +1610,21 @@ func (p *parser) parseSwitchStmt() ast.Stmt {
                p.exprLev = prevLev
        }
 
-       if isExprSwitch(s2) {
-               lbrace := p.expect(token.LBRACE)
-               var list []ast.Stmt
-               for p.tok == token.CASE || p.tok == token.DEFAULT {
-                       list = append(list, p.parseCaseClause())
-               }
-               rbrace := p.expect(token.RBRACE)
-               body := &ast.BlockStmt{lbrace, list, rbrace}
-               p.expectSemi()
-               return &ast.SwitchStmt{pos, s1, p.makeExpr(s2), body}
-       }
-
-       // type switch
-       // TODO(gri): do all the checks!
+       exprSwitch := isExprSwitch(s2)
        lbrace := p.expect(token.LBRACE)
        var list []ast.Stmt
        for p.tok == token.CASE || p.tok == token.DEFAULT {
-               list = append(list, p.parseTypeCaseClause())
+               list = append(list, p.parseCaseClause(exprSwitch))
        }
        rbrace := p.expect(token.RBRACE)
        p.expectSemi()
        body := &ast.BlockStmt{lbrace, list, rbrace}
+
+       if exprSwitch {
+               return &ast.SwitchStmt{pos, s1, p.makeExpr(s2), body}
+       }
+       // type switch
+       // TODO(gri): do all the checks!
        return &ast.TypeSwitchStmt{pos, s1, s2, body}
 }
 
@@ -1497,7 +1634,7 @@ func (p *parser) parseCommClause() *ast.CommClause {
                defer un(trace(p, "CommClause"))
        }
 
-       // CommCase
+       p.openScope()
        pos := p.pos
        var comm ast.Stmt
        if p.tok == token.CASE {
@@ -1518,7 +1655,7 @@ func (p *parser) parseCommClause() *ast.CommClause {
                        pos := p.pos
                        tok := p.tok
                        var rhs ast.Expr
-                       if p.tok == token.ASSIGN || p.tok == token.DEFINE {
+                       if tok == token.ASSIGN || tok == token.DEFINE {
                                // RecvStmt with assignment
                                if len(lhs) > 2 {
                                        p.errorExpected(lhs[0].Pos(), "1 or 2 expressions")
@@ -1527,6 +1664,9 @@ func (p *parser) parseCommClause() *ast.CommClause {
                                }
                                p.next()
                                rhs = p.parseExpr()
+                               if tok == token.DEFINE {
+                                       p.shortVarDecl(p.makeIdentList(lhs))
+                               }
                        } else {
                                // rhs must be single receive operation
                                if len(lhs) > 1 {
@@ -1552,6 +1692,7 @@ func (p *parser) parseCommClause() *ast.CommClause {
 
        colon := p.expect(token.COLON)
        body := p.parseStmtList()
+       p.closeScope()
 
        return &ast.CommClause{pos, comm, colon, body}
 }
@@ -1582,6 +1723,8 @@ func (p *parser) parseForStmt() ast.Stmt {
        }
 
        pos := p.expect(token.FOR)
+       p.openScope()
+       defer p.closeScope()
 
        var s1, s2, s3 ast.Stmt
        if p.tok != token.LBRACE {
@@ -1631,18 +1774,16 @@ func (p *parser) parseForStmt() ast.Stmt {
                        return &ast.BadStmt{pos, body.End()}
                }
                if rhs, isUnary := as.Rhs[0].(*ast.UnaryExpr); isUnary && rhs.Op == token.RANGE {
-                       // rhs is range expression; check lhs
+                       // rhs is range expression
+                       // (any short variable declaration was handled by parseSimpleStat above)
                        return &ast.RangeStmt{pos, key, value, as.TokPos, as.Tok, rhs.X, body}
-               } else {
-                       p.errorExpected(s2.Pos(), "range clause")
-                       return &ast.BadStmt{pos, body.End()}
                }
-       } else {
-               // regular for statement
-               return &ast.ForStmt{pos, s1, p.makeExpr(s2), s3, body}
+               p.errorExpected(s2.Pos(), "range clause")
+               return &ast.BadStmt{pos, body.End()}
        }
 
-       panic("unreachable")
+       // regular for statement
+       return &ast.ForStmt{pos, s1, p.makeExpr(s2), s3, body}
 }
 
 
@@ -1706,36 +1847,37 @@ func (p *parser) parseStmt() (s ast.Stmt) {
 // ----------------------------------------------------------------------------
 // Declarations
 
-type parseSpecFunction func(p *parser, doc *ast.CommentGroup) ast.Spec
+type parseSpecFunction func(p *parser, doc *ast.CommentGroup, iota int) ast.Spec
 
 
-func parseImportSpec(p *parser, doc *ast.CommentGroup) ast.Spec {
+func parseImportSpec(p *parser, doc *ast.CommentGroup, _ int) ast.Spec {
        if p.trace {
                defer un(trace(p, "ImportSpec"))
        }
 
        var ident *ast.Ident
-       if p.tok == token.PERIOD {
+       switch p.tok {
+       case token.PERIOD:
                ident = &ast.Ident{p.pos, ".", nil}
                p.next()
-       } else if p.tok == token.IDENT {
+       case token.IDENT:
                ident = p.parseIdent()
        }
 
        var path *ast.BasicLit
        if p.tok == token.STRING {
-               path = &ast.BasicLit{p.pos, p.tok, p.lit}
+               path = &ast.BasicLit{p.pos, p.tok, p.lit()}
                p.next()
        } else {
                p.expect(token.STRING) // use expect() error handling
        }
-       p.expectSemi()
+       p.expectSemi() // call before accessing p.linecomment
 
        return &ast.ImportSpec{doc, ident, path, p.lineComment}
 }
 
 
-func parseConstSpec(p *parser, doc *ast.CommentGroup) ast.Spec {
+func parseConstSpec(p *parser, doc *ast.CommentGroup, iota int) ast.Spec {
        if p.trace {
                defer un(trace(p, "ConstSpec"))
        }
@@ -1743,30 +1885,44 @@ func parseConstSpec(p *parser, doc *ast.CommentGroup) ast.Spec {
        idents := p.parseIdentList()
        typ := p.tryType()
        var values []ast.Expr
-       if typ != nil || p.tok == token.ASSIGN {
+       if typ != nil || p.tok == token.ASSIGN || iota == 0 {
                p.expect(token.ASSIGN)
                values = p.parseExprList()
        }
-       p.expectSemi()
+       p.expectSemi() // call before accessing p.linecomment
+
+       // Go spec: The scope of a constant or variable identifier declared inside
+       // a function begins at the end of the ConstSpec or VarSpec and ends at
+       // the end of the innermost containing block.
+       // (Global identifiers are resolved in a separate phase after parsing.)
+       spec := &ast.ValueSpec{doc, idents, typ, values, p.lineComment}
+       p.declare(spec, p.topScope, ast.Con, idents...)
 
-       return &ast.ValueSpec{doc, idents, typ, values, p.lineComment}
+       return spec
 }
 
 
-func parseTypeSpec(p *parser, doc *ast.CommentGroup) ast.Spec {
+func parseTypeSpec(p *parser, doc *ast.CommentGroup, _ int) ast.Spec {
        if p.trace {
                defer un(trace(p, "TypeSpec"))
        }
 
        ident := p.parseIdent()
        typ := p.parseType()
-       p.expectSemi()
+       p.expectSemi() // call before accessing p.linecomment
+
+       // Go spec: The scope of a type identifier declared inside a function begins
+       // at the identifier in the TypeSpec and ends at the end of the innermost
+       // containing block.
+       // (Global identifiers are resolved in a separate phase after parsing.)
+       spec := &ast.TypeSpec{doc, ident, typ, p.lineComment}
+       p.declare(spec, p.topScope, ast.Typ, ident)
 
-       return &ast.TypeSpec{doc, ident, typ, p.lineComment}
+       return spec
 }
 
 
-func parseVarSpec(p *parser, doc *ast.CommentGroup) ast.Spec {
+func parseVarSpec(p *parser, doc *ast.CommentGroup, _ int) ast.Spec {
        if p.trace {
                defer un(trace(p, "VarSpec"))
        }
@@ -1778,9 +1934,16 @@ func parseVarSpec(p *parser, doc *ast.CommentGroup) ast.Spec {
                p.expect(token.ASSIGN)
                values = p.parseExprList()
        }
-       p.expectSemi()
+       p.expectSemi() // call before accessing p.linecomment
+
+       // Go spec: The scope of a constant or variable identifier declared inside
+       // a function begins at the end of the ConstSpec or VarSpec and ends at
+       // the end of the innermost containing block.
+       // (Global identifiers are resolved in a separate phase after parsing.)
+       spec := &ast.ValueSpec{doc, idents, typ, values, p.lineComment}
+       p.declare(spec, p.topScope, ast.Var, idents...)
 
-       return &ast.ValueSpec{doc, idents, typ, values, p.lineComment}
+       return spec
 }
 
 
@@ -1796,26 +1959,26 @@ func (p *parser) parseGenDecl(keyword token.Token, f parseSpecFunction) *ast.Gen
        if p.tok == token.LPAREN {
                lparen = p.pos
                p.next()
-               for p.tok != token.RPAREN && p.tok != token.EOF {
-                       list = append(list, f(p, p.leadComment))
+               for iota := 0; p.tok != token.RPAREN && p.tok != token.EOF; iota++ {
+                       list = append(list, f(p, p.leadComment, iota))
                }
                rparen = p.expect(token.RPAREN)
                p.expectSemi()
        } else {
-               list = append(list, f(p, nil))
+               list = append(list, f(p, nil, 0))
        }
 
        return &ast.GenDecl{doc, pos, keyword, lparen, list, rparen}
 }
 
 
-func (p *parser) parseReceiver() *ast.FieldList {
+func (p *parser) parseReceiver(scope *ast.Scope) *ast.FieldList {
        if p.trace {
                defer un(trace(p, "Receiver"))
        }
 
        pos := p.pos
-       par := p.parseParameters(false)
+       par := p.parseParameters(scope, false)
 
        // must have exactly one receiver
        if par.NumFields() != 1 {
@@ -1844,22 +2007,37 @@ func (p *parser) parseFuncDecl() *ast.FuncDecl {
 
        doc := p.leadComment
        pos := p.expect(token.FUNC)
+       scope := ast.NewScope(p.topScope) // function scope
 
        var recv *ast.FieldList
        if p.tok == token.LPAREN {
-               recv = p.parseReceiver()
+               recv = p.parseReceiver(scope)
        }
 
        ident := p.parseIdent()
-       params, results := p.parseSignature()
+
+       params, results := p.parseSignature(scope)
 
        var body *ast.BlockStmt
        if p.tok == token.LBRACE {
-               body = p.parseBody()
+               body = p.parseBody(scope)
        }
        p.expectSemi()
 
-       return &ast.FuncDecl{doc, recv, ident, &ast.FuncType{pos, params, results}, body}
+       decl := &ast.FuncDecl{doc, recv, ident, &ast.FuncType{pos, params, results}, body}
+       if recv == nil {
+               // Go spec: The scope of an identifier denoting a constant, type,
+               // variable, or function (but not method) declared at top level
+               // (outside any function) is the package block.
+               //
+               // init() functions cannot be referred to and there may
+               // be more than one - don't put them in the pkgScope
+               if ident.Name != "init" {
+                       p.declare(decl, p.pkgScope, ast.Fun, ident)
+               }
+       }
+
+       return decl
 }
 
 
@@ -1918,6 +2096,8 @@ func (p *parser) parseFile() *ast.File {
        // package clause
        doc := p.leadComment
        pos := p.expect(token.PACKAGE)
+       // Go spec: The package clause is not a declaration;
+       // the package name does not appear in any scope.
        ident := p.parseIdent()
        p.expectSemi()
 
@@ -1940,5 +2120,20 @@ func (p *parser) parseFile() *ast.File {
                }
        }
 
-       return &ast.File{doc, pos, ident, decls, p.comments}
+       if p.topScope != p.pkgScope {
+               panic("internal error: imbalanced scopes")
+       }
+
+       // resolve global identifiers within the same file
+       i := 0
+       for _, ident := range p.unresolved {
+               // i <= index for current ident
+               ident.Obj = p.pkgScope.Lookup(ident.Name)
+               if ident.Obj == nil {
+                       p.unresolved[i] = ident
+                       i++
+               }
+       }
+
+       return &ast.File{doc, pos, ident, decls, p.pkgScope, p.unresolved[0:i], p.comments}
 }
index 38535627a75ec76029993c986f53e25b04a82c00..2f1ee6bfc094594355ad638d5261916b4833ceb3 100644 (file)
@@ -21,6 +21,7 @@ var illegalInputs = []interface{}{
        `package p; func f() { if /* should have condition */ {} };`,
        `package p; func f() { if ; /* should have condition */ {} };`,
        `package p; func f() { if f(); /* should have condition */ {} };`,
+       `package p; const c; /* should have constant value */`,
 }
 
 
@@ -73,7 +74,7 @@ var validFiles = []string{
 
 func TestParse3(t *testing.T) {
        for _, filename := range validFiles {
-               _, err := ParseFile(fset, filename, nil, 0)
+               _, err := ParseFile(fset, filename, nil, DeclarationErrors)
                if err != nil {
                        t.Errorf("ParseFile(%s): %v", filename, err)
                }
index 7933c2f182003a0ab631d5c785b005bcf2b9bed1..2238b6bedc80ba518b209846d67e29230a95c3ca 100644 (file)
@@ -108,17 +108,6 @@ func (p *printer) identList(list []*ast.Ident, indent bool, multiLine *bool) {
 }
 
 
-// Compute the key size of a key:value expression.
-// Returns 0 if the expression doesn't fit onto a single line.
-func (p *printer) keySize(pair *ast.KeyValueExpr) int {
-       if p.nodeSize(pair, infinity) <= infinity {
-               // entire expression fits on one line - return key size
-               return p.nodeSize(pair.Key, infinity)
-       }
-       return 0
-}
-
-
 // Print a list of expressions. If the list spans multiple
 // source lines, the original line breaks are respected between
 // expressions. Sets multiLine to true if the list spans multiple
@@ -204,17 +193,21 @@ func (p *printer) exprList(prev0 token.Pos, list []ast.Expr, depth int, mode exp
                //           the key and the node size into the decision process
                useFF := true
 
-               // determine size
+               // determine element size: all bets are off if we don't have
+               // position information for the previous and next token (likely
+               // generated code - simply ignore the size in this case by setting
+               // it to 0)
                prevSize := size
                const infinity = 1e6 // larger than any source line
                size = p.nodeSize(x, infinity)
                pair, isPair := x.(*ast.KeyValueExpr)
-               if size <= infinity {
+               if size <= infinity && prev.IsValid() && next.IsValid() {
                        // x fits on a single line
                        if isPair {
                                size = p.nodeSize(pair.Key, infinity) // size <= infinity
                        }
                } else {
+                       // size too large or we don't have good layout information
                        size = 0
                }
 
@@ -244,7 +237,6 @@ func (p *printer) exprList(prev0 token.Pos, list []ast.Expr, depth int, mode exp
                                // lines are broken using newlines so comments remain aligned
                                // unless forceFF is set or there are multiple expressions on
                                // the same line in which case formfeed is used
-                               // broken with a formfeed
                                if p.linebreak(line, linebreakMin, ws, useFF || prevBreak+1 < i) {
                                        ws = ignore
                                        *multiLine = true
@@ -375,7 +367,7 @@ func (p *printer) setLineComment(text string) {
 }
 
 
-func (p *printer) fieldList(fields *ast.FieldList, isIncomplete bool, ctxt exprContext) {
+func (p *printer) fieldList(fields *ast.FieldList, isStruct, isIncomplete bool) {
        p.nesting++
        defer func() {
                p.nesting--
@@ -384,15 +376,15 @@ func (p *printer) fieldList(fields *ast.FieldList, isIncomplete bool, ctxt exprC
        lbrace := fields.Opening
        list := fields.List
        rbrace := fields.Closing
+       srcIsOneLine := lbrace.IsValid() && rbrace.IsValid() && p.fset.Position(lbrace).Line == p.fset.Position(rbrace).Line
 
-       if !isIncomplete && !p.commentBefore(p.fset.Position(rbrace)) {
+       if !isIncomplete && !p.commentBefore(p.fset.Position(rbrace)) && srcIsOneLine {
                // possibly a one-line struct/interface
                if len(list) == 0 {
                        // no blank between keyword and {} in this case
                        p.print(lbrace, token.LBRACE, rbrace, token.RBRACE)
                        return
-               } else if ctxt&(compositeLit|structType) == compositeLit|structType &&
-                       p.isOneLineFieldList(list) { // for now ignore interfaces
+               } else if isStruct && p.isOneLineFieldList(list) { // for now ignore interfaces
                        // small enough - print on one line
                        // (don't use identList and ignore source line breaks)
                        p.print(lbrace, token.LBRACE, blank)
@@ -414,7 +406,7 @@ func (p *printer) fieldList(fields *ast.FieldList, isIncomplete bool, ctxt exprC
 
        // at least one entry or incomplete
        p.print(blank, lbrace, token.LBRACE, indent, formfeed)
-       if ctxt&structType != 0 {
+       if isStruct {
 
                sep := vtab
                if len(list) == 1 {
@@ -497,15 +489,6 @@ func (p *printer) fieldList(fields *ast.FieldList, isIncomplete bool, ctxt exprC
 // ----------------------------------------------------------------------------
 // Expressions
 
-// exprContext describes the syntactic environment in which an expression node is printed.
-type exprContext uint
-
-const (
-       compositeLit exprContext = 1 << iota
-       structType
-)
-
-
 func walkBinary(e *ast.BinaryExpr) (has4, has5 bool, maxProblem int) {
        switch e.Op.Precedence() {
        case 4:
@@ -650,7 +633,7 @@ func (p *printer) binaryExpr(x *ast.BinaryExpr, prec1, cutoff, depth int, multiL
        printBlank := prec < cutoff
 
        ws := indent
-       p.expr1(x.X, prec, depth+diffPrec(x.X, prec), 0, multiLine)
+       p.expr1(x.X, prec, depth+diffPrec(x.X, prec), multiLine)
        if printBlank {
                p.print(blank)
        }
@@ -669,7 +652,7 @@ func (p *printer) binaryExpr(x *ast.BinaryExpr, prec1, cutoff, depth int, multiL
        if printBlank {
                p.print(blank)
        }
-       p.expr1(x.Y, prec+1, depth+1, 0, multiLine)
+       p.expr1(x.Y, prec+1, depth+1, multiLine)
        if ws == ignore {
                p.print(unindent)
        }
@@ -742,7 +725,7 @@ func selectorExprList(expr ast.Expr) (list []ast.Expr) {
 
 
 // Sets multiLine to true if the expression spans multiple lines.
-func (p *printer) expr1(expr ast.Expr, prec1, depth int, ctxt exprContext, multiLine *bool) {
+func (p *printer) expr1(expr ast.Expr, prec1, depth int, multiLine *bool) {
        p.print(expr.Pos())
 
        switch x := expr.(type) {
@@ -792,7 +775,7 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int, ctxt exprContext, multi
                                // TODO(gri) Remove this code if it cannot be reached.
                                p.print(blank)
                        }
-                       p.expr1(x.X, prec, depth, 0, multiLine)
+                       p.expr1(x.X, prec, depth, multiLine)
                }
 
        case *ast.BasicLit:
@@ -818,7 +801,7 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int, ctxt exprContext, multi
                p.exprList(token.NoPos, parts, depth, periodSep, multiLine, token.NoPos)
 
        case *ast.TypeAssertExpr:
-               p.expr1(x.X, token.HighestPrec, depth, 0, multiLine)
+               p.expr1(x.X, token.HighestPrec, depth, multiLine)
                p.print(token.PERIOD, token.LPAREN)
                if x.Type != nil {
                        p.expr(x.Type, multiLine)
@@ -829,14 +812,14 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int, ctxt exprContext, multi
 
        case *ast.IndexExpr:
                // TODO(gri): should treat[] like parentheses and undo one level of depth
-               p.expr1(x.X, token.HighestPrec, 1, 0, multiLine)
+               p.expr1(x.X, token.HighestPrec, 1, multiLine)
                p.print(x.Lbrack, token.LBRACK)
                p.expr0(x.Index, depth+1, multiLine)
                p.print(x.Rbrack, token.RBRACK)
 
        case *ast.SliceExpr:
                // TODO(gri): should treat[] like parentheses and undo one level of depth
-               p.expr1(x.X, token.HighestPrec, 1, 0, multiLine)
+               p.expr1(x.X, token.HighestPrec, 1, multiLine)
                p.print(x.Lbrack, token.LBRACK)
                if x.Low != nil {
                        p.expr0(x.Low, depth+1, multiLine)
@@ -856,7 +839,7 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int, ctxt exprContext, multi
                if len(x.Args) > 1 {
                        depth++
                }
-               p.expr1(x.Fun, token.HighestPrec, depth, 0, multiLine)
+               p.expr1(x.Fun, token.HighestPrec, depth, multiLine)
                p.print(x.Lparen, token.LPAREN)
                p.exprList(x.Lparen, x.Args, depth, commaSep|commaTerm, multiLine, x.Rparen)
                if x.Ellipsis.IsValid() {
@@ -867,7 +850,7 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int, ctxt exprContext, multi
        case *ast.CompositeLit:
                // composite literal elements that are composite literals themselves may have the type omitted
                if x.Type != nil {
-                       p.expr1(x.Type, token.HighestPrec, depth, compositeLit, multiLine)
+                       p.expr1(x.Type, token.HighestPrec, depth, multiLine)
                }
                p.print(x.Lbrace, token.LBRACE)
                p.exprList(x.Lbrace, x.Elts, 1, commaSep|commaTerm, multiLine, x.Rbrace)
@@ -892,7 +875,7 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int, ctxt exprContext, multi
 
        case *ast.StructType:
                p.print(token.STRUCT)
-               p.fieldList(x.Fields, x.Incomplete, ctxt|structType)
+               p.fieldList(x.Fields, true, x.Incomplete)
 
        case *ast.FuncType:
                p.print(token.FUNC)
@@ -900,7 +883,7 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int, ctxt exprContext, multi
 
        case *ast.InterfaceType:
                p.print(token.INTERFACE)
-               p.fieldList(x.Methods, x.Incomplete, ctxt)
+               p.fieldList(x.Methods, false, x.Incomplete)
 
        case *ast.MapType:
                p.print(token.MAP, token.LBRACK)
@@ -929,14 +912,14 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int, ctxt exprContext, multi
 
 
 func (p *printer) expr0(x ast.Expr, depth int, multiLine *bool) {
-       p.expr1(x, token.LowestPrec, depth, 0, multiLine)
+       p.expr1(x, token.LowestPrec, depth, multiLine)
 }
 
 
 // Sets multiLine to true if the expression spans multiple lines.
 func (p *printer) expr(x ast.Expr, multiLine *bool) {
        const depth = 1
-       p.expr1(x, token.LowestPrec, depth, 0, multiLine)
+       p.expr1(x, token.LowestPrec, depth, multiLine)
 }
 
 
@@ -1145,9 +1128,9 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool, multiLine *bool) {
                }
 
        case *ast.CaseClause:
-               if s.Values != nil {
+               if s.List != nil {
                        p.print(token.CASE)
-                       p.exprList(s.Pos(), s.Values, 1, blankStart|commaSep, multiLine, s.Colon)
+                       p.exprList(s.Pos(), s.List, 1, blankStart|commaSep, multiLine, s.Colon)
                } else {
                        p.print(token.DEFAULT)
                }
@@ -1160,16 +1143,6 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool, multiLine *bool) {
                p.block(s.Body, 0)
                *multiLine = true
 
-       case *ast.TypeCaseClause:
-               if s.Types != nil {
-                       p.print(token.CASE)
-                       p.exprList(s.Pos(), s.Types, 1, blankStart|commaSep, multiLine, s.Colon)
-               } else {
-                       p.print(token.DEFAULT)
-               }
-               p.print(s.Colon, token.COLON)
-               p.stmtList(s.Body, 1, nextIsRBrace)
-
        case *ast.TypeSwitchStmt:
                p.print(token.SWITCH)
                if s.Init != nil {
@@ -1331,13 +1304,23 @@ func (p *printer) genDecl(d *ast.GenDecl, multiLine *bool) {
 // any control chars. Otherwise, the result is > maxSize.
 //
 func (p *printer) nodeSize(n ast.Node, maxSize int) (size int) {
+       // nodeSize invokes the printer, which may invoke nodeSize
+       // recursively. For deep composite literal nests, this can
+       // lead to an exponential algorithm. Remember previous
+       // results to prune the recursion (was issue 1628).
+       if size, found := p.nodeSizes[n]; found {
+               return size
+       }
+
        size = maxSize + 1 // assume n doesn't fit
+       p.nodeSizes[n] = size
+
        // nodeSize computation must be indendent of particular
        // style so that we always get the same decision; print
        // in RawFormat
        cfg := Config{Mode: RawFormat}
        var buf bytes.Buffer
-       if _, err := cfg.Fprint(&buf, p.fset, n); err != nil {
+       if _, err := cfg.fprint(&buf, p.fset, n, p.nodeSizes); err != nil {
                return
        }
        if buf.Len() <= maxSize {
@@ -1347,6 +1330,7 @@ func (p *printer) nodeSize(n ast.Node, maxSize int) (size int) {
                        }
                }
                size = buf.Len() // n fits
+               p.nodeSizes[n] = size
        }
        return
 }
index 48e2af1b736af7ac1b74fea78b38d9fca17792b8..a43e4a12c774688a0fa362dc16d390cfdfe5ad19 100644 (file)
@@ -12,7 +12,7 @@ import (
        "go/token"
        "io"
        "os"
-       "path"
+       "path/filepath"
        "runtime"
        "tabwriter"
 )
@@ -94,22 +94,23 @@ type printer struct {
        // written using writeItem.
        last token.Position
 
-       // HTML support
-       lastTaggedLine int // last line for which a line tag was written
-
        // The list of all source comments, in order of appearance.
        comments        []*ast.CommentGroup // may be nil
        cindex          int                 // current comment index
        useNodeComments bool                // if not set, ignore lead and line comments of nodes
+
+       // Cache of already computed node sizes.
+       nodeSizes map[ast.Node]int
 }
 
 
-func (p *printer) init(output io.Writer, cfg *Config, fset *token.FileSet) {
+func (p *printer) init(output io.Writer, cfg *Config, fset *token.FileSet, nodeSizes map[ast.Node]int) {
        p.output = output
        p.Config = *cfg
        p.fset = fset
        p.errors = make(chan os.Error)
        p.buffer = make([]whiteSpace, 0, 16) // whitespace sequences are short
+       p.nodeSizes = nodeSizes
 }
 
 
@@ -244,7 +245,7 @@ func (p *printer) writeItem(pos token.Position, data []byte) {
        }
        if debug {
                // do not update p.pos - use write0
-               _, filename := path.Split(pos.Filename)
+               _, filename := filepath.Split(pos.Filename)
                p.write0([]byte(fmt.Sprintf("[%s:%d:%d]", filename, pos.Line, pos.Column)))
        }
        p.write(data)
@@ -994,13 +995,8 @@ type Config struct {
 }
 
 
-// Fprint "pretty-prints" an AST node to output and returns the number
-// of bytes written and an error (if any) for a given configuration cfg.
-// Position information is interpreted relative to the file set fset.
-// The node type must be *ast.File, or assignment-compatible to ast.Expr,
-// ast.Decl, ast.Spec, or ast.Stmt.
-//
-func (cfg *Config) Fprint(output io.Writer, fset *token.FileSet, node interface{}) (int, os.Error) {
+// fprint implements Fprint and takes a nodesSizes map for setting up the printer state.
+func (cfg *Config) fprint(output io.Writer, fset *token.FileSet, node interface{}, nodeSizes map[ast.Node]int) (int, os.Error) {
        // redirect output through a trimmer to eliminate trailing whitespace
        // (Input to a tabwriter must be untrimmed since trailing tabs provide
        // formatting information. The tabwriter could provide trimming
@@ -1029,7 +1025,7 @@ func (cfg *Config) Fprint(output io.Writer, fset *token.FileSet, node interface{
 
        // setup printer and print node
        var p printer
-       p.init(output, cfg, fset)
+       p.init(output, cfg, fset, nodeSizes)
        go func() {
                switch n := node.(type) {
                case ast.Expr:
@@ -1076,6 +1072,17 @@ func (cfg *Config) Fprint(output io.Writer, fset *token.FileSet, node interface{
 }
 
 
+// Fprint "pretty-prints" an AST node to output and returns the number
+// of bytes written and an error (if any) for a given configuration cfg.
+// Position information is interpreted relative to the file set fset.
+// The node type must be *ast.File, or assignment-compatible to ast.Expr,
+// ast.Decl, ast.Spec, or ast.Stmt.
+//
+func (cfg *Config) Fprint(output io.Writer, fset *token.FileSet, node interface{}) (int, os.Error) {
+       return cfg.fprint(output, fset, node, make(map[ast.Node]int))
+}
+
+
 // Fprint "pretty-prints" an AST node to output.
 // It calls Config.Fprint with default settings.
 //
index 565075aa20cd948ee07db7c33679fcdb48108b16..3ff087e2993b58358fdddea3e08b6fd8e14400f8 100644 (file)
@@ -11,8 +11,9 @@ import (
        "go/ast"
        "go/parser"
        "go/token"
-       "path"
+       "path/filepath"
        "testing"
+       "time"
 )
 
 
@@ -45,7 +46,7 @@ const (
 )
 
 
-func check(t *testing.T, source, golden string, mode checkMode) {
+func runcheck(t *testing.T, source, golden string, mode checkMode) {
        // parse source
        prog, err := parser.ParseFile(fset, source, nil, parser.ParseComments)
        if err != nil {
@@ -109,6 +110,32 @@ func check(t *testing.T, source, golden string, mode checkMode) {
 }
 
 
+func check(t *testing.T, source, golden string, mode checkMode) {
+       // start a timer to produce a time-out signal
+       tc := make(chan int)
+       go func() {
+               time.Sleep(20e9) // plenty of a safety margin, even for very slow machines
+               tc <- 0
+       }()
+
+       // run the test
+       cc := make(chan int)
+       go func() {
+               runcheck(t, source, golden, mode)
+               cc <- 0
+       }()
+
+       // wait for the first finisher
+       select {
+       case <-tc:
+               // test running past time out
+               t.Errorf("%s: running too slowly", source)
+       case <-cc:
+               // test finished within alloted time margin
+       }
+}
+
+
 type entry struct {
        source, golden string
        mode           checkMode
@@ -124,13 +151,14 @@ var data = []entry{
        {"expressions.input", "expressions.raw", rawFormat},
        {"declarations.input", "declarations.golden", 0},
        {"statements.input", "statements.golden", 0},
+       {"slow.input", "slow.golden", 0},
 }
 
 
 func TestFiles(t *testing.T) {
        for _, e := range data {
-               source := path.Join(dataDir, e.source)
-               golden := path.Join(dataDir, e.golden)
+               source := filepath.Join(dataDir, e.source)
+               golden := filepath.Join(dataDir, e.golden)
                check(t, source, golden, e.mode)
                // TODO(gri) check that golden is idempotent
                //check(t, golden, golden, e.mode);
index 7f18f338a633ccf8fc45a4e434b1d61b10e55e44..314d3213c740a0ecc3972d5b528ff3f5c5a373ca 100644 (file)
@@ -224,11 +224,7 @@ func _() {
        _ = struct{ x int }{0}
        _ = struct{ x, y, z int }{0, 1, 2}
        _ = struct{ int }{0}
-       _ = struct {
-               s struct {
-                       int
-               }
-       }{struct{ int }{0}}     // compositeLit context not propagated => multiLine result
+       _ = struct{ s struct{ int } }{struct{ int }{0}}
 }
 
 
index 6bcd9b5f89ebd37a2d368c342e9026e7db952886..cac22af431312fa08d5eb23318183593e2c5b596 100644 (file)
@@ -224,7 +224,7 @@ func _() {
        _ = struct{ x int }{0}
        _ = struct{ x, y, z int }{0, 1, 2}
        _ = struct{ int }{0}
-       _ = struct{ s struct { int } }{struct{ int}{0}}  // compositeLit context not propagated => multiLine result
+       _ = struct{ s struct { int } }{struct{ int}{0} }
 }
 
 
index f1944c94bb4d9ebd3fbac8fc480cf9e204d433c1..f22ceeb476f9083af9b88a498a383b56828d15dc 100644 (file)
@@ -224,11 +224,7 @@ func _() {
        _ = struct{ x int }{0}
        _ = struct{ x, y, z int }{0, 1, 2}
        _ = struct{ int }{0}
-       _ = struct {
-               s struct {
-                       int
-               }
-       }{struct{ int }{0}}     // compositeLit context not propagated => multiLine result
+       _ = struct{ s struct{ int } }{struct{ int }{0}}
 }
 
 
diff --git a/libgo/go/go/printer/testdata/slow.golden b/libgo/go/go/printer/testdata/slow.golden
new file mode 100644 (file)
index 0000000..43a15cb
--- /dev/null
@@ -0,0 +1,85 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package deepequal_test
+
+import (
+       "testing"
+       "google3/spam/archer/frontend/deepequal"
+)
+
+func TestTwoNilValues(t *testing.T) {
+       if err := deepequal.Check(nil, nil); err != nil {
+               t.Errorf("expected nil, saw %v", err)
+       }
+}
+
+type Foo struct {
+       bar     *Bar
+       bang    *Bar
+}
+
+type Bar struct {
+       baz     *Baz
+       foo     []*Foo
+}
+
+type Baz struct {
+       entries         map[int]interface{}
+       whatever        string
+}
+
+func newFoo() *Foo {
+       return &Foo{bar: &Bar{baz: &Baz{
+               entries: map[int]interface{}{
+                       42:     &Foo{},
+                       21:     &Bar{},
+                       11:     &Baz{whatever: "it's just a test"}}}},
+               bang: &Bar{foo: []*Foo{
+                       &Foo{bar: &Bar{baz: &Baz{
+                               entries: map[int]interface{}{
+                                       43:     &Foo{},
+                                       22:     &Bar{},
+                                       13:     &Baz{whatever: "this is nuts"}}}},
+                               bang: &Bar{foo: []*Foo{
+                                       &Foo{bar: &Bar{baz: &Baz{
+                                               entries: map[int]interface{}{
+                                                       61:     &Foo{},
+                                                       71:     &Bar{},
+                                                       11:     &Baz{whatever: "no, it's Go"}}}},
+                                               bang: &Bar{foo: []*Foo{
+                                                       &Foo{bar: &Bar{baz: &Baz{
+                                                               entries: map[int]interface{}{
+                                                                       0:      &Foo{},
+                                                                       -2:     &Bar{},
+                                                                       -11:    &Baz{whatever: "we need to go deeper"}}}},
+                                                               bang: &Bar{foo: []*Foo{
+                                                                       &Foo{bar: &Bar{baz: &Baz{
+                                                                               entries: map[int]interface{}{
+                                                                                       -2:     &Foo{},
+                                                                                       -5:     &Bar{},
+                                                                                       -7:     &Baz{whatever: "are you serious?"}}}},
+                                                                               bang:   &Bar{foo: []*Foo{}}},
+                                                                       &Foo{bar: &Bar{baz: &Baz{
+                                                                               entries: map[int]interface{}{
+                                                                                       -100:   &Foo{},
+                                                                                       50:     &Bar{},
+                                                                                       20:     &Baz{whatever: "na, not really ..."}}}},
+                                                                               bang:   &Bar{foo: []*Foo{}}}}}}}}},
+                                       &Foo{bar: &Bar{baz: &Baz{
+                                               entries: map[int]interface{}{
+                                                       2:      &Foo{},
+                                                       1:      &Bar{},
+                                                       -1:     &Baz{whatever: "... it's just a test."}}}},
+                                               bang:   &Bar{foo: []*Foo{}}}}}}}}}
+}
+
+func TestElaborate(t *testing.T) {
+       a := newFoo()
+       b := newFoo()
+
+       if err := deepequal.Check(a, b); err != nil {
+               t.Errorf("expected nil, saw %v", err)
+       }
+}
diff --git a/libgo/go/go/printer/testdata/slow.input b/libgo/go/go/printer/testdata/slow.input
new file mode 100644 (file)
index 0000000..0e5a23d
--- /dev/null
@@ -0,0 +1,85 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package deepequal_test
+
+import (
+        "testing"
+        "google3/spam/archer/frontend/deepequal"
+)
+
+func TestTwoNilValues(t *testing.T) {
+        if err := deepequal.Check(nil, nil); err != nil {
+                t.Errorf("expected nil, saw %v", err)
+        }
+}
+
+type Foo struct {
+        bar *Bar
+        bang *Bar
+}
+
+type Bar struct {
+        baz *Baz
+        foo []*Foo
+}
+
+type Baz struct {
+        entries  map[int]interface{}
+        whatever string
+}
+
+func newFoo() (*Foo) {
+return &Foo{bar: &Bar{ baz: &Baz{
+entries: map[int]interface{}{
+42: &Foo{},
+21: &Bar{},
+11: &Baz{ whatever: "it's just a test" }}}},
+        bang: &Bar{foo: []*Foo{
+&Foo{bar: &Bar{ baz: &Baz{
+entries: map[int]interface{}{
+43: &Foo{},
+22: &Bar{},
+13: &Baz{ whatever: "this is nuts" }}}},
+        bang: &Bar{foo: []*Foo{
+&Foo{bar: &Bar{ baz: &Baz{
+entries: map[int]interface{}{
+61: &Foo{},
+71: &Bar{},
+11: &Baz{ whatever: "no, it's Go" }}}},
+        bang: &Bar{foo: []*Foo{
+&Foo{bar: &Bar{ baz: &Baz{
+entries: map[int]interface{}{
+0: &Foo{},
+-2: &Bar{},
+-11: &Baz{ whatever: "we need to go deeper" }}}},
+        bang: &Bar{foo: []*Foo{
+&Foo{bar: &Bar{ baz: &Baz{
+entries: map[int]interface{}{
+-2: &Foo{},
+-5: &Bar{},
+-7: &Baz{ whatever: "are you serious?" }}}},
+        bang: &Bar{foo: []*Foo{}}},
+&Foo{bar: &Bar{ baz: &Baz{
+entries: map[int]interface{}{
+-100: &Foo{},
+50: &Bar{},
+20: &Baz{ whatever: "na, not really ..." }}}},
+        bang: &Bar{foo: []*Foo{}}}}}}}}},
+&Foo{bar: &Bar{ baz: &Baz{
+entries: map[int]interface{}{
+2: &Foo{},
+1: &Bar{},
+-1: &Baz{ whatever: "... it's just a test." }}}},
+        bang: &Bar{foo: []*Foo{}}}}}}}}}
+}
+
+func TestElaborate(t *testing.T) {
+        a := newFoo()
+        b := newFoo()
+
+        if err := deepequal.Check(a, b); err != nil {
+                t.Errorf("expected nil, saw %v", err)
+        }
+}
index 2ae296b3f1595dbafd42590380725a4fa77b2d78..59fed9dffc6bfe919238bd3b9eb1671dd855fb87 100644 (file)
@@ -23,7 +23,7 @@ package scanner
 import (
        "bytes"
        "go/token"
-       "path"
+       "path/filepath"
        "strconv"
        "unicode"
        "utf8"
@@ -118,7 +118,7 @@ func (S *Scanner) Init(file *token.File, src []byte, err ErrorHandler, mode uint
                panic("file size does not match src len")
        }
        S.file = file
-       S.dir, _ = path.Split(file.Name())
+       S.dir, _ = filepath.Split(file.Name())
        S.src = src
        S.err = err
        S.mode = mode
@@ -177,13 +177,13 @@ var prefix = []byte("//line ")
 func (S *Scanner) interpretLineComment(text []byte) {
        if bytes.HasPrefix(text, prefix) {
                // get filename and line number, if any
-               if i := bytes.Index(text, []byte{':'}); i > 0 {
+               if i := bytes.LastIndex(text, []byte{':'}); i > 0 {
                        if line, err := strconv.Atoi(string(text[i+1:])); err == nil && line > 0 {
                                // valid //line filename:line comment;
-                               filename := path.Clean(string(text[len(prefix):i]))
-                               if filename[0] != '/' {
+                               filename := filepath.Clean(string(text[len(prefix):i]))
+                               if !filepath.IsAbs(filename) {
                                        // make filename relative to current directory
-                                       filename = path.Join(S.dir, filename)
+                                       filename = filepath.Join(S.dir, filename)
                                }
                                // update scanner position
                                S.file.AddLineInfo(S.lineOffset, filename, line-1) // -1 since comment applies to next line
index c622ff482f34bb19d68751f5f2f15787b332f929..93f34581b7fc6e22be5b851018b2e1810b3d589e 100644 (file)
@@ -7,6 +7,8 @@ package scanner
 import (
        "go/token"
        "os"
+       "path/filepath"
+       "runtime"
        "testing"
 )
 
@@ -443,32 +445,41 @@ func TestSemis(t *testing.T) {
        }
 }
 
-
-var segments = []struct {
+type segment struct {
        srcline  string // a line of source text
        filename string // filename for current token
        line     int    // line number for current token
-}{
+}
+
+var segments = []segment{
        // exactly one token per line since the test consumes one token per segment
-       {"  line1", "dir/TestLineComments", 1},
-       {"\nline2", "dir/TestLineComments", 2},
-       {"\nline3  //line File1.go:100", "dir/TestLineComments", 3}, // bad line comment, ignored
-       {"\nline4", "dir/TestLineComments", 4},
-       {"\n//line File1.go:100\n  line100", "dir/File1.go", 100},
-       {"\n//line File2.go:200\n  line200", "dir/File2.go", 200},
+       {"  line1", filepath.Join("dir", "TestLineComments"), 1},
+       {"\nline2", filepath.Join("dir", "TestLineComments"), 2},
+       {"\nline3  //line File1.go:100", filepath.Join("dir", "TestLineComments"), 3}, // bad line comment, ignored
+       {"\nline4", filepath.Join("dir", "TestLineComments"), 4},
+       {"\n//line File1.go:100\n  line100", filepath.Join("dir", "File1.go"), 100},
+       {"\n//line File2.go:200\n  line200", filepath.Join("dir", "File2.go"), 200},
        {"\n//line :1\n  line1", "dir", 1},
-       {"\n//line foo:42\n  line42", "dir/foo", 42},
-       {"\n //line foo:42\n  line44", "dir/foo", 44},           // bad line comment, ignored
-       {"\n//line foo 42\n  line46", "dir/foo", 46},            // bad line comment, ignored
-       {"\n//line foo:42 extra text\n  line48", "dir/foo", 48}, // bad line comment, ignored
-       {"\n//line /bar:42\n  line42", "/bar", 42},
-       {"\n//line ./foo:42\n  line42", "dir/foo", 42},
-       {"\n//line a/b/c/File1.go:100\n  line100", "dir/a/b/c/File1.go", 100},
+       {"\n//line foo:42\n  line42", filepath.Join("dir", "foo"), 42},
+       {"\n //line foo:42\n  line44", filepath.Join("dir", "foo"), 44},           // bad line comment, ignored
+       {"\n//line foo 42\n  line46", filepath.Join("dir", "foo"), 46},            // bad line comment, ignored
+       {"\n//line foo:42 extra text\n  line48", filepath.Join("dir", "foo"), 48}, // bad line comment, ignored
+       {"\n//line /bar:42\n  line42", string(filepath.Separator) + "bar", 42},
+       {"\n//line ./foo:42\n  line42", filepath.Join("dir", "foo"), 42},
+       {"\n//line a/b/c/File1.go:100\n  line100", filepath.Join("dir", "a", "b", "c", "File1.go"), 100},
+}
+
+var winsegments = []segment{
+       {"\n//line c:\\dir\\File1.go:100\n  line100", "c:\\dir\\File1.go", 100},
 }
 
 
 // Verify that comments of the form "//line filename:line" are interpreted correctly.
 func TestLineComments(t *testing.T) {
+       if runtime.GOOS == "windows" {
+               segments = append(segments, winsegments...)
+       }
+
        // make source
        var src string
        for _, e := range segments {
@@ -477,7 +488,7 @@ func TestLineComments(t *testing.T) {
 
        // verify scan
        var S Scanner
-       file := fset.AddFile("dir/TestLineComments", fset.Base(), len(src))
+       file := fset.AddFile(filepath.Join("dir", "TestLineComments"), fset.Base(), len(src))
        S.Init(file, []byte(src), nil, 0)
        for _, s := range segments {
                p, _, lit := S.Scan()
index 114c93ea86ef65ff2e981c4f4dfae2ee2566e8fa..bd24f4ca42dc3d911a8a4a12897d4e8d77211f38 100644 (file)
@@ -2,15 +2,15 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// This file implements scope support functions.
+// DEPRECATED FILE - WILL GO AWAY EVENTUALLY.
+//
+// Scope handling is now done in go/parser.
+// The functionality here is only present to
+// keep the typechecker running for now.
 
 package typechecker
 
-import (
-       "fmt"
-       "go/ast"
-       "go/token"
-)
+import "go/ast"
 
 
 func (tc *typechecker) openScope() *ast.Scope {
@@ -24,52 +24,25 @@ func (tc *typechecker) closeScope() {
 }
 
 
-// objPos computes the source position of the declaration of an object name.
-// Only required for error reporting, so doesn't have to be fast.
-func objPos(obj *ast.Object) (pos token.Pos) {
-       switch d := obj.Decl.(type) {
-       case *ast.Field:
-               for _, n := range d.Names {
-                       if n.Name == obj.Name {
-                               return n.Pos()
-                       }
-               }
-       case *ast.ValueSpec:
-               for _, n := range d.Names {
-                       if n.Name == obj.Name {
-                               return n.Pos()
-                       }
-               }
-       case *ast.TypeSpec:
-               return d.Name.Pos()
-       case *ast.FuncDecl:
-               return d.Name.Pos()
-       }
-       if debug {
-               fmt.Printf("decl = %T\n", obj.Decl)
-       }
-       panic("unreachable")
-}
-
-
 // declInScope declares an object of a given kind and name in scope and sets the object's Decl and N fields.
 // It returns the newly allocated object. If an object with the same name already exists in scope, an error
 // is reported and the object is not inserted.
-// (Objects with _ name are always inserted into a scope without errors, but they cannot be found.)
-func (tc *typechecker) declInScope(scope *ast.Scope, kind ast.Kind, name *ast.Ident, decl interface{}, n int) *ast.Object {
+func (tc *typechecker) declInScope(scope *ast.Scope, kind ast.ObjKind, name *ast.Ident, decl interface{}, n int) *ast.Object {
        obj := ast.NewObj(kind, name.Name)
        obj.Decl = decl
-       obj.N = n
+       //obj.N = n
        name.Obj = obj
-       if alt := scope.Insert(obj); alt != obj {
-               tc.Errorf(name.Pos(), "%s already declared at %s", name.Name, objPos(alt))
+       if name.Name != "_" {
+               if alt := scope.Insert(obj); alt != obj {
+                       tc.Errorf(name.Pos(), "%s already declared at %s", name.Name, tc.fset.Position(alt.Pos()).String())
+               }
        }
        return obj
 }
 
 
 // decl is the same as declInScope(tc.topScope, ...)
-func (tc *typechecker) decl(kind ast.Kind, name *ast.Ident, decl interface{}, n int) *ast.Object {
+func (tc *typechecker) decl(kind ast.ObjKind, name *ast.Ident, decl interface{}, n int) *ast.Object {
        return tc.declInScope(tc.topScope, kind, name, decl, n)
 }
 
@@ -91,7 +64,7 @@ func (tc *typechecker) find(name *ast.Ident) (obj *ast.Object) {
 
 // findField returns the object with the given name if visible in the type's scope.
 // If no such object is found, an error is reported and a bad object is returned instead.
-func (tc *typechecker) findField(typ *ast.Type, name *ast.Ident) (obj *ast.Object) {
+func (tc *typechecker) findField(typ *Type, name *ast.Ident) (obj *ast.Object) {
        // TODO(gri) This is simplistic at the moment and ignores anonymous fields.
        obj = typ.Scope.Lookup(name.Name)
        if obj == nil {
@@ -100,20 +73,3 @@ func (tc *typechecker) findField(typ *ast.Type, name *ast.Ident) (obj *ast.Objec
        }
        return
 }
-
-
-// printScope prints the objects in a scope.
-func printScope(scope *ast.Scope) {
-       fmt.Printf("scope %p {", scope)
-       if scope != nil && len(scope.Objects) > 0 {
-               fmt.Println()
-               for _, obj := range scope.Objects {
-                       form := "void"
-                       if obj.Type != nil {
-                               form = obj.Type.Form.String()
-                       }
-                       fmt.Printf("\t%s\t%s\n", obj.Name, form)
-               }
-       }
-       fmt.Printf("}\n")
-}
diff --git a/libgo/go/go/typechecker/testdata/test0.go b/libgo/go/go/typechecker/testdata/test0.go
deleted file mode 100644 (file)
index 4e317f2..0000000
+++ /dev/null
@@ -1,94 +0,0 @@
-// Copyright 2010 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// type declarations
-
-package P0
-
-type (
-       B bool
-       I int32
-       A [10]P
-       T struct {
-               x, y P
-       }
-       P *T
-       R *R
-       F func(A) I
-       Y interface {
-               f(A) I
-       }
-       S []P
-       M map[I]F
-       C chan<- I
-)
-
-type (
-       a/* ERROR "illegal cycle" */ a
-       a/* ERROR "already declared" */ int
-
-       b/* ERROR "illegal cycle" */ c
-       c d
-       d e
-       e b /* ERROR "not a type" */
-
-       t *t
-
-       U V
-       V W
-       W *U
-
-       P1 *S2
-       P2 P1
-
-       S1 struct {
-               a, b, c int
-               u, v, a/* ERROR "already declared" */ float
-       }
-       S2/* ERROR "illegal cycle" */ struct {
-               x S2
-       }
-
-       L1 []L1
-       L2 []int
-
-       A1 [10]int
-       A2/* ERROR "illegal cycle" */ [10]A2
-       A3/* ERROR "illegal cycle" */ [10]struct {
-               x A4
-       }
-       A4 [10]A3
-
-       F1 func()
-       F2 func(x, y, z float)
-       F3 func(x, y, x /* ERROR "already declared" */ float)
-       F4 func() (x, y, x /* ERROR "already declared" */ float)
-       F5 func(x int) (x /* ERROR "already declared" */ float)
-
-       I1 interface{}
-       I2 interface {
-               m1()
-       }
-       I3 interface {
-               m1()
-               m1 /* ERROR "already declared" */ ()
-       }
-       I4 interface {
-               m1(x, y, x /* ERROR "already declared" */ float)
-               m2() (x, y, x /* ERROR "already declared" */ float)
-               m3(x int) (x /* ERROR "already declared" */ float)
-       }
-       I5 interface {
-               m1(I5)
-       }
-
-       C1 chan int
-       C2 <-chan int
-       C3 chan<- C3
-
-       M1 map[Last]string
-       M2 map[string]M2
-
-       Last int
-)
diff --git a/libgo/go/go/typechecker/testdata/test0.src b/libgo/go/go/typechecker/testdata/test0.src
new file mode 100644 (file)
index 0000000..4e317f2
--- /dev/null
@@ -0,0 +1,94 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// type declarations
+
+package P0
+
+type (
+       B bool
+       I int32
+       A [10]P
+       T struct {
+               x, y P
+       }
+       P *T
+       R *R
+       F func(A) I
+       Y interface {
+               f(A) I
+       }
+       S []P
+       M map[I]F
+       C chan<- I
+)
+
+type (
+       a/* ERROR "illegal cycle" */ a
+       a/* ERROR "already declared" */ int
+
+       b/* ERROR "illegal cycle" */ c
+       c d
+       d e
+       e b /* ERROR "not a type" */
+
+       t *t
+
+       U V
+       V W
+       W *U
+
+       P1 *S2
+       P2 P1
+
+       S1 struct {
+               a, b, c int
+               u, v, a/* ERROR "already declared" */ float
+       }
+       S2/* ERROR "illegal cycle" */ struct {
+               x S2
+       }
+
+       L1 []L1
+       L2 []int
+
+       A1 [10]int
+       A2/* ERROR "illegal cycle" */ [10]A2
+       A3/* ERROR "illegal cycle" */ [10]struct {
+               x A4
+       }
+       A4 [10]A3
+
+       F1 func()
+       F2 func(x, y, z float)
+       F3 func(x, y, x /* ERROR "already declared" */ float)
+       F4 func() (x, y, x /* ERROR "already declared" */ float)
+       F5 func(x int) (x /* ERROR "already declared" */ float)
+
+       I1 interface{}
+       I2 interface {
+               m1()
+       }
+       I3 interface {
+               m1()
+               m1 /* ERROR "already declared" */ ()
+       }
+       I4 interface {
+               m1(x, y, x /* ERROR "already declared" */ float)
+               m2() (x, y, x /* ERROR "already declared" */ float)
+               m3(x int) (x /* ERROR "already declared" */ float)
+       }
+       I5 interface {
+               m1(I5)
+       }
+
+       C1 chan int
+       C2 <-chan int
+       C3 chan<- C3
+
+       M1 map[Last]string
+       M2 map[string]M2
+
+       Last int
+)
diff --git a/libgo/go/go/typechecker/testdata/test1.go b/libgo/go/go/typechecker/testdata/test1.go
deleted file mode 100644 (file)
index b0808ee..0000000
+++ /dev/null
@@ -1,13 +0,0 @@
-// Copyright 2010 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// const and var declarations
-
-package P1
-
-const (
-       c1         /* ERROR "missing initializer" */
-       c2     int = 0
-       c3, c4 = 0
-)
diff --git a/libgo/go/go/typechecker/testdata/test1.src b/libgo/go/go/typechecker/testdata/test1.src
new file mode 100644 (file)
index 0000000..b5531fb
--- /dev/null
@@ -0,0 +1,13 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// const and var declarations
+
+package P1
+
+const (
+       c1 = 0
+       c2     int = 0
+       c3, c4 = 0
+)
diff --git a/libgo/go/go/typechecker/testdata/test3.go b/libgo/go/go/typechecker/testdata/test3.go
deleted file mode 100644 (file)
index ea35808..0000000
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2010 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package P3
-
-// function and method signatures
-
-func _()                                        {}
-func _()                                        {}
-func _(x, x /* ERROR "already declared" */ int) {}
-
-func f()                                 {}
-func f /* ERROR "already declared" */ () {}
-
-func (*foo /* ERROR "invalid receiver" */ ) m() {}
-func (bar /* ERROR "not a type" */ ) m()        {}
-
-func f1(x, _, _ int) (_, _ float)                              {}
-func f2(x, y, x /* ERROR "already declared" */ int)            {}
-func f3(x, y int) (a, b, x /* ERROR "already declared" */ int) {}
-
-func (x *T) m1()                                 {}
-func (x *T) m1 /* ERROR "already declared" */ () {}
-func (x T) m1 /* ERROR "already declared" */ ()  {}
-func (T) m1 /* ERROR "already declared" */ ()    {}
-
-func (x *T) m2(u, x /* ERROR "already declared" */ int)               {}
-func (x *T) m3(a, b, c int) (u, x /* ERROR "already declared" */ int) {}
-func (T) _(x, x /* ERROR "already declared" */ int)                   {}
-func (T) _() (x, x /* ERROR "already declared" */ int)                {}
-
-//func (PT) _() {}
-
-var bar int
-
-type T struct{}
-type PT (T)
diff --git a/libgo/go/go/typechecker/testdata/test3.src b/libgo/go/go/typechecker/testdata/test3.src
new file mode 100644 (file)
index 0000000..2e1a9fa
--- /dev/null
@@ -0,0 +1,41 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package P3
+
+// function and method signatures
+
+func _()                                        {}
+func _()                                        {}
+func _(x, x /* ERROR "already declared" */ int) {}
+
+func f()                                 {}
+func f /* ERROR "already declared" */ () {}
+
+func (*foo /* ERROR "invalid receiver" */ ) m() {}
+func (bar /* ERROR "not a type" */ ) m()        {}
+
+func f1(x, _, _ int) (_, _ float)                              {}
+func f2(x, y, x /* ERROR "already declared" */ int)            {}
+func f3(x, y int) (a, b, x /* ERROR "already declared" */ int) {}
+
+func (x *T) m1()                                 {}
+func (x *T) m1 /* ERROR "already declared" */ () {}
+func (x T) m1 /* ERROR "already declared" */ ()  {}
+func (T) m1 /* ERROR "already declared" */ ()    {}
+
+func (x *T) m2(u, x /* ERROR "already declared" */ int)               {}
+func (x *T) m3(a, b, c int) (u, x /* ERROR "already declared" */ int) {}
+// The following are disabled for now because the typechecker
+// in in the process of being rewritten and cannot handle them
+// at the moment
+//func (T) _(x, x /* "already declared" */ int)                   {}
+//func (T) _() (x, x /* "already declared" */ int)                {}
+
+//func (PT) _() {}
+
+var bar int
+
+type T struct{}
+type PT (T)
diff --git a/libgo/go/go/typechecker/testdata/test4.go b/libgo/go/go/typechecker/testdata/test4.go
deleted file mode 100644 (file)
index bb9aee3..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-// Copyright 2010 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// Constant declarations
-
-package P4
-
-const (
-       c0 /* ERROR "missing initializer" */
-)
diff --git a/libgo/go/go/typechecker/testdata/test4.src b/libgo/go/go/typechecker/testdata/test4.src
new file mode 100644 (file)
index 0000000..94d3558
--- /dev/null
@@ -0,0 +1,11 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Constant declarations
+
+package P4
+
+const (
+       c0 = 0
+)
diff --git a/libgo/go/go/typechecker/type.go b/libgo/go/go/typechecker/type.go
new file mode 100644 (file)
index 0000000..62b4e9d
--- /dev/null
@@ -0,0 +1,125 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package typechecker
+
+import "go/ast"
+
+
+// A Type represents a Go type.
+type Type struct {
+       Form     Form
+       Obj      *ast.Object // corresponding type name, or nil
+       Scope    *ast.Scope  // fields and methods, always present
+       N        uint        // basic type id, array length, number of function results, or channel direction
+       Key, Elt *Type       // map key and array, pointer, slice, map or channel element
+       Params   *ast.Scope  // function (receiver, input and result) parameters, tuple expressions (results of function calls), or nil
+       Expr     ast.Expr    // corresponding AST expression
+}
+
+
+// NewType creates a new type of a given form.
+func NewType(form Form) *Type {
+       return &Type{Form: form, Scope: ast.NewScope(nil)}
+}
+
+
+// Form describes the form of a type.
+type Form int
+
+// The list of possible type forms.
+const (
+       BadType    Form = iota // for error handling
+       Unresolved             // type not fully setup
+       Basic
+       Array
+       Struct
+       Pointer
+       Function
+       Method
+       Interface
+       Slice
+       Map
+       Channel
+       Tuple
+)
+
+
+var formStrings = [...]string{
+       BadType:    "badType",
+       Unresolved: "unresolved",
+       Basic:      "basic",
+       Array:      "array",
+       Struct:     "struct",
+       Pointer:    "pointer",
+       Function:   "function",
+       Method:     "method",
+       Interface:  "interface",
+       Slice:      "slice",
+       Map:        "map",
+       Channel:    "channel",
+       Tuple:      "tuple",
+}
+
+
+func (form Form) String() string { return formStrings[form] }
+
+
+// The list of basic type id's.
+const (
+       Bool = iota
+       Byte
+       Uint
+       Int
+       Float
+       Complex
+       Uintptr
+       String
+
+       Uint8
+       Uint16
+       Uint32
+       Uint64
+
+       Int8
+       Int16
+       Int32
+       Int64
+
+       Float32
+       Float64
+
+       Complex64
+       Complex128
+
+       // TODO(gri) ideal types are missing
+)
+
+
+var BasicTypes = map[uint]string{
+       Bool:    "bool",
+       Byte:    "byte",
+       Uint:    "uint",
+       Int:     "int",
+       Float:   "float",
+       Complex: "complex",
+       Uintptr: "uintptr",
+       String:  "string",
+
+       Uint8:  "uint8",
+       Uint16: "uint16",
+       Uint32: "uint32",
+       Uint64: "uint64",
+
+       Int8:  "int8",
+       Int16: "int16",
+       Int32: "int32",
+       Int64: "int64",
+
+       Float32: "float32",
+       Float64: "float64",
+
+       Complex64:  "complex64",
+       Complex128: "complex128",
+}
index e9aefa2402b0cc062268cf7c1eb4737706a30fc9..4fc5647f0d5d055ab3b8b3115a336e0e7914a219 100644 (file)
@@ -65,6 +65,7 @@ type typechecker struct {
        fset *token.FileSet
        scanner.ErrorVector
        importer Importer
+       globals  []*ast.Object        // list of global objects
        topScope *ast.Scope           // current top-most scope
        cyclemap map[*ast.Object]bool // for cycle detection
        iota     int                  // current value of iota
@@ -94,7 +95,7 @@ phase 1: declare all global objects; also collect all function and method declar
        - report global double declarations
 
 phase 2: bind methods to their receiver base types
-       - received base types must be declared in the package, thus for
+       - receiver base types must be declared in the package, thus for
          each method a corresponding (unresolved) type must exist
        - report method double declarations and errors with base types
 
@@ -142,16 +143,16 @@ func (tc *typechecker) checkPackage(pkg *ast.Package) {
        }
 
        // phase 3: resolve all global objects
-       // (note that objects with _ name are also in the scope)
        tc.cyclemap = make(map[*ast.Object]bool)
-       for _, obj := range tc.topScope.Objects {
+       for _, obj := range tc.globals {
                tc.resolve(obj)
        }
        assert(len(tc.cyclemap) == 0)
 
        // 4: sequentially typecheck function and method bodies
        for _, f := range funcs {
-               tc.checkBlock(f.Body.List, f.Name.Obj.Type)
+               ftype, _ := f.Name.Obj.Type.(*Type)
+               tc.checkBlock(f.Body.List, ftype)
        }
 
        pkg.Scope = tc.topScope
@@ -183,11 +184,11 @@ func (tc *typechecker) declGlobal(global ast.Decl) {
                                                }
                                        }
                                        for _, name := range s.Names {
-                                               tc.decl(ast.Con, name, s, iota)
+                                               tc.globals = append(tc.globals, tc.decl(ast.Con, name, s, iota))
                                        }
                                case token.VAR:
                                        for _, name := range s.Names {
-                                               tc.decl(ast.Var, name, s, 0)
+                                               tc.globals = append(tc.globals, tc.decl(ast.Var, name, s, 0))
                                        }
                                default:
                                        panic("unreachable")
@@ -196,9 +197,10 @@ func (tc *typechecker) declGlobal(global ast.Decl) {
                                iota++
                        case *ast.TypeSpec:
                                obj := tc.decl(ast.Typ, s.Name, s, 0)
+                               tc.globals = append(tc.globals, obj)
                                // give all type objects an unresolved type so
                                // that we can collect methods in the type scope
-                               typ := ast.NewType(ast.Unresolved)
+                               typ := NewType(Unresolved)
                                obj.Type = typ
                                typ.Obj = obj
                        default:
@@ -208,7 +210,7 @@ func (tc *typechecker) declGlobal(global ast.Decl) {
 
        case *ast.FuncDecl:
                if d.Recv == nil {
-                       tc.decl(ast.Fun, d.Name, d, 0)
+                       tc.globals = append(tc.globals, tc.decl(ast.Fun, d.Name, d, 0))
                }
 
        default:
@@ -239,8 +241,8 @@ func (tc *typechecker) bindMethod(method *ast.FuncDecl) {
                } else if obj.Kind != ast.Typ {
                        tc.Errorf(name.Pos(), "invalid receiver: %s is not a type", name.Name)
                } else {
-                       typ := obj.Type
-                       assert(typ.Form == ast.Unresolved)
+                       typ := obj.Type.(*Type)
+                       assert(typ.Form == Unresolved)
                        scope = typ.Scope
                }
        }
@@ -261,7 +263,7 @@ func (tc *typechecker) bindMethod(method *ast.FuncDecl) {
 func (tc *typechecker) resolve(obj *ast.Object) {
        // check for declaration cycles
        if tc.cyclemap[obj] {
-               tc.Errorf(objPos(obj), "illegal cycle in declaration of %s", obj.Name)
+               tc.Errorf(obj.Pos(), "illegal cycle in declaration of %s", obj.Name)
                obj.Kind = ast.Bad
                return
        }
@@ -271,7 +273,7 @@ func (tc *typechecker) resolve(obj *ast.Object) {
        }()
 
        // resolve non-type objects
-       typ := obj.Type
+       typ, _ := obj.Type.(*Type)
        if typ == nil {
                switch obj.Kind {
                case ast.Bad:
@@ -282,12 +284,12 @@ func (tc *typechecker) resolve(obj *ast.Object) {
 
                case ast.Var:
                        tc.declVar(obj)
-                       //obj.Type = tc.typeFor(nil, obj.Decl.(*ast.ValueSpec).Type, false)
+                       obj.Type = tc.typeFor(nil, obj.Decl.(*ast.ValueSpec).Type, false)
 
                case ast.Fun:
-                       obj.Type = ast.NewType(ast.Function)
+                       obj.Type = NewType(Function)
                        t := obj.Decl.(*ast.FuncDecl).Type
-                       tc.declSignature(obj.Type, nil, t.Params, t.Results)
+                       tc.declSignature(obj.Type.(*Type), nil, t.Params, t.Results)
 
                default:
                        // type objects have non-nil types when resolve is called
@@ -300,32 +302,34 @@ func (tc *typechecker) resolve(obj *ast.Object) {
        }
 
        // resolve type objects
-       if typ.Form == ast.Unresolved {
+       if typ.Form == Unresolved {
                tc.typeFor(typ, typ.Obj.Decl.(*ast.TypeSpec).Type, false)
 
                // provide types for all methods
                for _, obj := range typ.Scope.Objects {
                        if obj.Kind == ast.Fun {
                                assert(obj.Type == nil)
-                               obj.Type = ast.NewType(ast.Method)
+                               obj.Type = NewType(Method)
                                f := obj.Decl.(*ast.FuncDecl)
                                t := f.Type
-                               tc.declSignature(obj.Type, f.Recv, t.Params, t.Results)
+                               tc.declSignature(obj.Type.(*Type), f.Recv, t.Params, t.Results)
                        }
                }
        }
 }
 
 
-func (tc *typechecker) checkBlock(body []ast.Stmt, ftype *ast.Type) {
+func (tc *typechecker) checkBlock(body []ast.Stmt, ftype *Type) {
        tc.openScope()
        defer tc.closeScope()
 
        // inject function/method parameters into block scope, if any
        if ftype != nil {
                for _, par := range ftype.Params.Objects {
-                       obj := tc.topScope.Insert(par)
-                       assert(obj == par) // ftype has no double declarations
+                       if par.Name != "_" {
+                               obj := tc.topScope.Insert(par)
+                               assert(obj == par) // ftype has no double declarations
+                       }
                }
        }
 
@@ -362,8 +366,8 @@ func (tc *typechecker) declFields(scope *ast.Scope, fields *ast.FieldList, ref b
 }
 
 
-func (tc *typechecker) declSignature(typ *ast.Type, recv, params, results *ast.FieldList) {
-       assert((typ.Form == ast.Method) == (recv != nil))
+func (tc *typechecker) declSignature(typ *Type, recv, params, results *ast.FieldList) {
+       assert((typ.Form == Method) == (recv != nil))
        typ.Params = ast.NewScope(nil)
        tc.declFields(typ.Params, recv, true)
        tc.declFields(typ.Params, params, true)
@@ -371,7 +375,7 @@ func (tc *typechecker) declSignature(typ *ast.Type, recv, params, results *ast.F
 }
 
 
-func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Type) {
+func (tc *typechecker) typeFor(def *Type, x ast.Expr, ref bool) (typ *Type) {
        x = unparen(x)
 
        // type name
@@ -381,10 +385,10 @@ func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Ty
                if obj.Kind != ast.Typ {
                        tc.Errorf(t.Pos(), "%s is not a type", t.Name)
                        if def == nil {
-                               typ = ast.NewType(ast.BadType)
+                               typ = NewType(BadType)
                        } else {
                                typ = def
-                               typ.Form = ast.BadType
+                               typ.Form = BadType
                        }
                        typ.Expr = x
                        return
@@ -393,7 +397,7 @@ func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Ty
                if !ref {
                        tc.resolve(obj) // check for cycles even if type resolved
                }
-               typ = obj.Type
+               typ = obj.Type.(*Type)
 
                if def != nil {
                        // new type declaration: copy type structure
@@ -410,7 +414,7 @@ func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Ty
        // type literal
        typ = def
        if typ == nil {
-               typ = ast.NewType(ast.BadType)
+               typ = NewType(BadType)
        }
        typ.Expr = x
 
@@ -419,42 +423,42 @@ func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Ty
                if debug {
                        fmt.Println("qualified identifier unimplemented")
                }
-               typ.Form = ast.BadType
+               typ.Form = BadType
 
        case *ast.StarExpr:
-               typ.Form = ast.Pointer
+               typ.Form = Pointer
                typ.Elt = tc.typeFor(nil, t.X, true)
 
        case *ast.ArrayType:
                if t.Len != nil {
-                       typ.Form = ast.Array
+                       typ.Form = Array
                        // TODO(gri) compute the real length
                        // (this may call resolve recursively)
                        (*typ).N = 42
                } else {
-                       typ.Form = ast.Slice
+                       typ.Form = Slice
                }
                typ.Elt = tc.typeFor(nil, t.Elt, t.Len == nil)
 
        case *ast.StructType:
-               typ.Form = ast.Struct
+               typ.Form = Struct
                tc.declFields(typ.Scope, t.Fields, false)
 
        case *ast.FuncType:
-               typ.Form = ast.Function
+               typ.Form = Function
                tc.declSignature(typ, nil, t.Params, t.Results)
 
        case *ast.InterfaceType:
-               typ.Form = ast.Interface
+               typ.Form = Interface
                tc.declFields(typ.Scope, t.Methods, true)
 
        case *ast.MapType:
-               typ.Form = ast.Map
+               typ.Form = Map
                typ.Key = tc.typeFor(nil, t.Key, true)
                typ.Elt = tc.typeFor(nil, t.Value, true)
 
        case *ast.ChanType:
-               typ.Form = ast.Channel
+               typ.Form = Channel
                typ.N = uint(t.Dir)
                typ.Elt = tc.typeFor(nil, t.Value, true)
 
index 33f4a6223ff41a62dc8d8618fb042590bc2ece94..3988ff1680b9bdac9929f1658d9651dda4a733d1 100644 (file)
@@ -93,7 +93,7 @@ func expectedErrors(t *testing.T, pkg *ast.Package) (list scanner.ErrorList) {
 
 
 func testFilter(f *os.FileInfo) bool {
-       return strings.HasSuffix(f.Name, ".go") && f.Name[0] != '.'
+       return strings.HasSuffix(f.Name, ".src") && f.Name[0] != '.'
 }
 
 
index db950737f39f75362cf9397fc552ebe93f4328a4..cf4434993e126363aeff7029747489b2f37280f7 100644 (file)
@@ -24,8 +24,8 @@ func init() {
        Universe = ast.NewScope(nil)
 
        // basic types
-       for n, name := range ast.BasicTypes {
-               typ := ast.NewType(ast.Basic)
+       for n, name := range BasicTypes {
+               typ := NewType(Basic)
                typ.N = n
                obj := ast.NewObj(ast.Typ, name)
                obj.Type = typ
index c822d6863ac6eb868a92b6156e52cd1598138447..28042ccaa3a2db8a815eaa9ac079797fd6207eca 100644 (file)
@@ -50,7 +50,7 @@ func testError(t *testing.T) {
 func TestUintCodec(t *testing.T) {
        defer testError(t)
        b := new(bytes.Buffer)
-       encState := newEncoderState(nil, b)
+       encState := newEncoderState(b)
        for _, tt := range encodeT {
                b.Reset()
                encState.encodeUint(tt.x)
@@ -58,7 +58,7 @@ func TestUintCodec(t *testing.T) {
                        t.Errorf("encodeUint: %#x encode: expected % x got % x", tt.x, tt.b, b.Bytes())
                }
        }
-       decState := newDecodeState(nil, b)
+       decState := newDecodeState(b)
        for u := uint64(0); ; u = (u + 1) * 7 {
                b.Reset()
                encState.encodeUint(u)
@@ -75,9 +75,9 @@ func TestUintCodec(t *testing.T) {
 func verifyInt(i int64, t *testing.T) {
        defer testError(t)
        var b = new(bytes.Buffer)
-       encState := newEncoderState(nil, b)
+       encState := newEncoderState(b)
        encState.encodeInt(i)
-       decState := newDecodeState(nil, b)
+       decState := newDecodeState(b)
        decState.buf = make([]byte, 8)
        j := decState.decodeInt()
        if i != j {
@@ -111,9 +111,16 @@ var complexResult = []byte{0x07, 0xFE, 0x31, 0x40, 0xFE, 0x33, 0x40}
 // The result of encoding "hello" with field number 7
 var bytesResult = []byte{0x07, 0x05, 'h', 'e', 'l', 'l', 'o'}
 
-func newencoderState(b *bytes.Buffer) *encoderState {
+func newDecodeState(buf *bytes.Buffer) *decoderState {
+       d := new(decoderState)
+       d.b = buf
+       d.buf = make([]byte, uint64Size)
+       return d
+}
+
+func newEncoderState(b *bytes.Buffer) *encoderState {
        b.Reset()
-       state := newEncoderState(nil, b)
+       state := &encoderState{enc: nil, b: b}
        state.fieldnum = -1
        return state
 }
@@ -127,7 +134,7 @@ func TestScalarEncInstructions(t *testing.T) {
        {
                data := struct{ a bool }{true}
                instr := &encInstr{encBool, 6, 0, 0}
-               state := newencoderState(b)
+               state := newEncoderState(b)
                instr.op(instr, state, unsafe.Pointer(&data))
                if !bytes.Equal(boolResult, b.Bytes()) {
                        t.Errorf("bool enc instructions: expected % x got % x", boolResult, b.Bytes())
@@ -139,7 +146,7 @@ func TestScalarEncInstructions(t *testing.T) {
                b.Reset()
                data := struct{ a int }{17}
                instr := &encInstr{encInt, 6, 0, 0}
-               state := newencoderState(b)
+               state := newEncoderState(b)
                instr.op(instr, state, unsafe.Pointer(&data))
                if !bytes.Equal(signedResult, b.Bytes()) {
                        t.Errorf("int enc instructions: expected % x got % x", signedResult, b.Bytes())
@@ -151,7 +158,7 @@ func TestScalarEncInstructions(t *testing.T) {
                b.Reset()
                data := struct{ a uint }{17}
                instr := &encInstr{encUint, 6, 0, 0}
-               state := newencoderState(b)
+               state := newEncoderState(b)
                instr.op(instr, state, unsafe.Pointer(&data))
                if !bytes.Equal(unsignedResult, b.Bytes()) {
                        t.Errorf("uint enc instructions: expected % x got % x", unsignedResult, b.Bytes())
@@ -163,7 +170,7 @@ func TestScalarEncInstructions(t *testing.T) {
                b.Reset()
                data := struct{ a int8 }{17}
                instr := &encInstr{encInt8, 6, 0, 0}
-               state := newencoderState(b)
+               state := newEncoderState(b)
                instr.op(instr, state, unsafe.Pointer(&data))
                if !bytes.Equal(signedResult, b.Bytes()) {
                        t.Errorf("int8 enc instructions: expected % x got % x", signedResult, b.Bytes())
@@ -175,7 +182,7 @@ func TestScalarEncInstructions(t *testing.T) {
                b.Reset()
                data := struct{ a uint8 }{17}
                instr := &encInstr{encUint8, 6, 0, 0}
-               state := newencoderState(b)
+               state := newEncoderState(b)
                instr.op(instr, state, unsafe.Pointer(&data))
                if !bytes.Equal(unsignedResult, b.Bytes()) {
                        t.Errorf("uint8 enc instructions: expected % x got % x", unsignedResult, b.Bytes())
@@ -187,7 +194,7 @@ func TestScalarEncInstructions(t *testing.T) {
                b.Reset()
                data := struct{ a int16 }{17}
                instr := &encInstr{encInt16, 6, 0, 0}
-               state := newencoderState(b)
+               state := newEncoderState(b)
                instr.op(instr, state, unsafe.Pointer(&data))
                if !bytes.Equal(signedResult, b.Bytes()) {
                        t.Errorf("int16 enc instructions: expected % x got % x", signedResult, b.Bytes())
@@ -199,7 +206,7 @@ func TestScalarEncInstructions(t *testing.T) {
                b.Reset()
                data := struct{ a uint16 }{17}
                instr := &encInstr{encUint16, 6, 0, 0}
-               state := newencoderState(b)
+               state := newEncoderState(b)
                instr.op(instr, state, unsafe.Pointer(&data))
                if !bytes.Equal(unsignedResult, b.Bytes()) {
                        t.Errorf("uint16 enc instructions: expected % x got % x", unsignedResult, b.Bytes())
@@ -211,7 +218,7 @@ func TestScalarEncInstructions(t *testing.T) {
                b.Reset()
                data := struct{ a int32 }{17}
                instr := &encInstr{encInt32, 6, 0, 0}
-               state := newencoderState(b)
+               state := newEncoderState(b)
                instr.op(instr, state, unsafe.Pointer(&data))
                if !bytes.Equal(signedResult, b.Bytes()) {
                        t.Errorf("int32 enc instructions: expected % x got % x", signedResult, b.Bytes())
@@ -223,7 +230,7 @@ func TestScalarEncInstructions(t *testing.T) {
                b.Reset()
                data := struct{ a uint32 }{17}
                instr := &encInstr{encUint32, 6, 0, 0}
-               state := newencoderState(b)
+               state := newEncoderState(b)
                instr.op(instr, state, unsafe.Pointer(&data))
                if !bytes.Equal(unsignedResult, b.Bytes()) {
                        t.Errorf("uint32 enc instructions: expected % x got % x", unsignedResult, b.Bytes())
@@ -235,7 +242,7 @@ func TestScalarEncInstructions(t *testing.T) {
                b.Reset()
                data := struct{ a int64 }{17}
                instr := &encInstr{encInt64, 6, 0, 0}
-               state := newencoderState(b)
+               state := newEncoderState(b)
                instr.op(instr, state, unsafe.Pointer(&data))
                if !bytes.Equal(signedResult, b.Bytes()) {
                        t.Errorf("int64 enc instructions: expected % x got % x", signedResult, b.Bytes())
@@ -247,7 +254,7 @@ func TestScalarEncInstructions(t *testing.T) {
                b.Reset()
                data := struct{ a uint64 }{17}
                instr := &encInstr{encUint64, 6, 0, 0}
-               state := newencoderState(b)
+               state := newEncoderState(b)
                instr.op(instr, state, unsafe.Pointer(&data))
                if !bytes.Equal(unsignedResult, b.Bytes()) {
                        t.Errorf("uint64 enc instructions: expected % x got % x", unsignedResult, b.Bytes())
@@ -259,7 +266,7 @@ func TestScalarEncInstructions(t *testing.T) {
                b.Reset()
                data := struct{ a float32 }{17}
                instr := &encInstr{encFloat32, 6, 0, 0}
-               state := newencoderState(b)
+               state := newEncoderState(b)
                instr.op(instr, state, unsafe.Pointer(&data))
                if !bytes.Equal(floatResult, b.Bytes()) {
                        t.Errorf("float32 enc instructions: expected % x got % x", floatResult, b.Bytes())
@@ -271,7 +278,7 @@ func TestScalarEncInstructions(t *testing.T) {
                b.Reset()
                data := struct{ a float64 }{17}
                instr := &encInstr{encFloat64, 6, 0, 0}
-               state := newencoderState(b)
+               state := newEncoderState(b)
                instr.op(instr, state, unsafe.Pointer(&data))
                if !bytes.Equal(floatResult, b.Bytes()) {
                        t.Errorf("float64 enc instructions: expected % x got % x", floatResult, b.Bytes())
@@ -283,7 +290,7 @@ func TestScalarEncInstructions(t *testing.T) {
                b.Reset()
                data := struct{ a []byte }{[]byte("hello")}
                instr := &encInstr{encUint8Array, 6, 0, 0}
-               state := newencoderState(b)
+               state := newEncoderState(b)
                instr.op(instr, state, unsafe.Pointer(&data))
                if !bytes.Equal(bytesResult, b.Bytes()) {
                        t.Errorf("bytes enc instructions: expected % x got % x", bytesResult, b.Bytes())
@@ -295,7 +302,7 @@ func TestScalarEncInstructions(t *testing.T) {
                b.Reset()
                data := struct{ a string }{"hello"}
                instr := &encInstr{encString, 6, 0, 0}
-               state := newencoderState(b)
+               state := newEncoderState(b)
                instr.op(instr, state, unsafe.Pointer(&data))
                if !bytes.Equal(bytesResult, b.Bytes()) {
                        t.Errorf("string enc instructions: expected % x got % x", bytesResult, b.Bytes())
@@ -303,7 +310,7 @@ func TestScalarEncInstructions(t *testing.T) {
        }
 }
 
-func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p unsafe.Pointer) {
+func execDec(typ string, instr *decInstr, state *decoderState, t *testing.T, p unsafe.Pointer) {
        defer testError(t)
        v := int(state.decodeUint())
        if v+state.fieldnum != 6 {
@@ -313,9 +320,9 @@ func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p un
        state.fieldnum = 6
 }
 
-func newDecodeStateFromData(data []byte) *decodeState {
+func newDecodeStateFromData(data []byte) *decoderState {
        b := bytes.NewBuffer(data)
-       state := newDecodeState(nil, b)
+       state := newDecodeState(b)
        state.fieldnum = -1
        return state
 }
@@ -997,9 +1004,9 @@ func TestInvalidField(t *testing.T) {
        var bad0 Bad0
        bad0.CH = make(chan int)
        b := new(bytes.Buffer)
-       var nilEncoder *Encoder
-       err := nilEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0)))
-       if err == nil {
+       dummyEncoder := new(Encoder) // sufficient for this purpose.
+       dummyEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0)))
+       if err := dummyEncoder.err; err == nil {
                t.Error("expected error; got none")
        } else if strings.Index(err.String(), "type") < 0 {
                t.Error("expected type error; got", err)
index 8f599e100413c1f371da25d0c52b5a49786aeb7a..f8159d4ea32357fd20a69524ed9a89b2593eedc5 100644 (file)
@@ -13,38 +13,47 @@ import (
        "math"
        "os"
        "reflect"
-       "unicode"
        "unsafe"
-       "utf8"
 )
 
 var (
        errBadUint = os.ErrorString("gob: encoded unsigned integer out of range")
        errBadType = os.ErrorString("gob: unknown type id or corrupted data")
-       errRange   = os.ErrorString("gob: internal error: field numbers out of bounds")
+       errRange   = os.ErrorString("gob: bad data: field numbers out of bounds")
 )
 
-// The execution state of an instance of the decoder. A new state
+// decoderState is the execution state of an instance of the decoder. A new state
 // is created for nested objects.
-type decodeState struct {
+type decoderState struct {
        dec *Decoder
        // The buffer is stored with an extra indirection because it may be replaced
        // if we load a type during decode (when reading an interface value).
        b        *bytes.Buffer
        fieldnum int // the last field number read.
        buf      []byte
+       next     *decoderState // for free list
 }
 
 // We pass the bytes.Buffer separately for easier testing of the infrastructure
 // without requiring a full Decoder.
-func newDecodeState(dec *Decoder, buf *bytes.Buffer) *decodeState {
-       d := new(decodeState)
-       d.dec = dec
+func (dec *Decoder) newDecoderState(buf *bytes.Buffer) *decoderState {
+       d := dec.freeList
+       if d == nil {
+               d = new(decoderState)
+               d.dec = dec
+               d.buf = make([]byte, uint64Size)
+       } else {
+               dec.freeList = d.next
+       }
        d.b = buf
-       d.buf = make([]byte, uint64Size)
        return d
 }
 
+func (dec *Decoder) freeDecoderState(d *decoderState) {
+       d.next = dec.freeList
+       dec.freeList = d
+}
+
 func overflow(name string) os.ErrorString {
        return os.ErrorString(`value for "` + name + `" out of range`)
 }
@@ -85,7 +94,7 @@ func decodeUintReader(r io.Reader, buf []byte) (x uint64, width int, err os.Erro
 
 // decodeUint reads an encoded unsigned integer from state.r.
 // Does not check for overflow.
-func (state *decodeState) decodeUint() (x uint64) {
+func (state *decoderState) decodeUint() (x uint64) {
        b, err := state.b.ReadByte()
        if err != nil {
                error(err)
@@ -112,7 +121,7 @@ func (state *decodeState) decodeUint() (x uint64) {
 
 // decodeInt reads an encoded signed integer from state.r.
 // Does not check for overflow.
-func (state *decodeState) decodeInt() int64 {
+func (state *decoderState) decodeInt() int64 {
        x := state.decodeUint()
        if x&1 != 0 {
                return ^int64(x >> 1)
@@ -120,7 +129,8 @@ func (state *decodeState) decodeInt() int64 {
        return int64(x >> 1)
 }
 
-type decOp func(i *decInstr, state *decodeState, p unsafe.Pointer)
+// decOp is the signature of a decoding operator for a given type.
+type decOp func(i *decInstr, state *decoderState, p unsafe.Pointer)
 
 // The 'instructions' of the decoding machine
 type decInstr struct {
@@ -150,26 +160,31 @@ func decIndirect(p unsafe.Pointer, indir int) unsafe.Pointer {
        return p
 }
 
-func ignoreUint(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// ignoreUint discards a uint value with no destination.
+func ignoreUint(i *decInstr, state *decoderState, p unsafe.Pointer) {
        state.decodeUint()
 }
 
-func ignoreTwoUints(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// ignoreTwoUints discards a uint value with no destination. It's used to skip
+// complex values.
+func ignoreTwoUints(i *decInstr, state *decoderState, p unsafe.Pointer) {
        state.decodeUint()
        state.decodeUint()
 }
 
-func decBool(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decBool decodes a uiint and stores it as a boolean through p.
+func decBool(i *decInstr, state *decoderState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
                        *(*unsafe.Pointer)(p) = unsafe.Pointer(new(bool))
                }
                p = *(*unsafe.Pointer)(p)
        }
-       *(*bool)(p) = state.decodeInt() != 0
+       *(*bool)(p) = state.decodeUint() != 0
 }
 
-func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decInt8 decodes an integer and stores it as an int8 through p.
+func decInt8(i *decInstr, state *decoderState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
                        *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int8))
@@ -184,7 +199,8 @@ func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) {
        }
 }
 
-func decUint8(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decUint8 decodes an unsigned integer and stores it as a uint8 through p.
+func decUint8(i *decInstr, state *decoderState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
                        *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint8))
@@ -199,7 +215,8 @@ func decUint8(i *decInstr, state *decodeState, p unsafe.Pointer) {
        }
 }
 
-func decInt16(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decInt16 decodes an integer and stores it as an int16 through p.
+func decInt16(i *decInstr, state *decoderState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
                        *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int16))
@@ -214,7 +231,8 @@ func decInt16(i *decInstr, state *decodeState, p unsafe.Pointer) {
        }
 }
 
-func decUint16(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decUint16 decodes an unsigned integer and stores it as a uint16 through p.
+func decUint16(i *decInstr, state *decoderState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
                        *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint16))
@@ -229,7 +247,8 @@ func decUint16(i *decInstr, state *decodeState, p unsafe.Pointer) {
        }
 }
 
-func decInt32(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decInt32 decodes an integer and stores it as an int32 through p.
+func decInt32(i *decInstr, state *decoderState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
                        *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int32))
@@ -244,7 +263,8 @@ func decInt32(i *decInstr, state *decodeState, p unsafe.Pointer) {
        }
 }
 
-func decUint32(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decUint32 decodes an unsigned integer and stores it as a uint32 through p.
+func decUint32(i *decInstr, state *decoderState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
                        *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint32))
@@ -259,7 +279,8 @@ func decUint32(i *decInstr, state *decodeState, p unsafe.Pointer) {
        }
 }
 
-func decInt64(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decInt64 decodes an integer and stores it as an int64 through p.
+func decInt64(i *decInstr, state *decoderState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
                        *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int64))
@@ -269,7 +290,8 @@ func decInt64(i *decInstr, state *decodeState, p unsafe.Pointer) {
        *(*int64)(p) = int64(state.decodeInt())
 }
 
-func decUint64(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decUint64 decodes an unsigned integer and stores it as a uint64 through p.
+func decUint64(i *decInstr, state *decoderState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
                        *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint64))
@@ -294,7 +316,9 @@ func floatFromBits(u uint64) float64 {
        return math.Float64frombits(v)
 }
 
-func storeFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// storeFloat32 decodes an unsigned integer, treats it as a 32-bit floating-point
+// number, and stores it through p. It's a helper function for float32 and complex64.
+func storeFloat32(i *decInstr, state *decoderState, p unsafe.Pointer) {
        v := floatFromBits(state.decodeUint())
        av := v
        if av < 0 {
@@ -308,7 +332,9 @@ func storeFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) {
        }
 }
 
-func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decFloat32 decodes an unsigned integer, treats it as a 32-bit floating-point
+// number, and stores it through p.
+func decFloat32(i *decInstr, state *decoderState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
                        *(*unsafe.Pointer)(p) = unsafe.Pointer(new(float32))
@@ -318,7 +344,9 @@ func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) {
        storeFloat32(i, state, p)
 }
 
-func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decFloat64 decodes an unsigned integer, treats it as a 64-bit floating-point
+// number, and stores it through p.
+func decFloat64(i *decInstr, state *decoderState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
                        *(*unsafe.Pointer)(p) = unsafe.Pointer(new(float64))
@@ -328,8 +356,10 @@ func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) {
        *(*float64)(p) = floatFromBits(uint64(state.decodeUint()))
 }
 
-// Complex numbers are just a pair of floating-point numbers, real part first.
-func decComplex64(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decComplex64 decodes a pair of unsigned integers, treats them as a
+// pair of floating point numbers, and stores them as a complex64 through p.
+// The real part comes first.
+func decComplex64(i *decInstr, state *decoderState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
                        *(*unsafe.Pointer)(p) = unsafe.Pointer(new(complex64))
@@ -340,7 +370,10 @@ func decComplex64(i *decInstr, state *decodeState, p unsafe.Pointer) {
        storeFloat32(i, state, unsafe.Pointer(uintptr(p)+uintptr(unsafe.Sizeof(float32(0)))))
 }
 
-func decComplex128(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decComplex128 decodes a pair of unsigned integers, treats them as a
+// pair of floating point numbers, and stores them as a complex128 through p.
+// The real part comes first.
+func decComplex128(i *decInstr, state *decoderState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
                        *(*unsafe.Pointer)(p) = unsafe.Pointer(new(complex128))
@@ -352,8 +385,10 @@ func decComplex128(i *decInstr, state *decodeState, p unsafe.Pointer) {
        *(*complex128)(p) = complex(real, imag)
 }
 
+// decUint8Array decodes byte array and stores through p a slice header
+// describing the data.
 // uint8 arrays are encoded as an unsigned count followed by the raw bytes.
-func decUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) {
+func decUint8Array(i *decInstr, state *decoderState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
                        *(*unsafe.Pointer)(p) = unsafe.Pointer(new([]uint8))
@@ -365,8 +400,10 @@ func decUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) {
        *(*[]uint8)(p) = b
 }
 
+// decString decodes byte array and stores through p a string header
+// describing the data.
 // Strings are encoded as an unsigned count followed by the raw bytes.
-func decString(i *decInstr, state *decodeState, p unsafe.Pointer) {
+func decString(i *decInstr, state *decoderState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
                        *(*unsafe.Pointer)(p) = unsafe.Pointer(new([]byte))
@@ -375,10 +412,18 @@ func decString(i *decInstr, state *decodeState, p unsafe.Pointer) {
        }
        b := make([]byte, state.decodeUint())
        state.b.Read(b)
-       *(*string)(p) = string(b)
+       // It would be a shame to do the obvious thing here,
+       //      *(*string)(p) = string(b)
+       // because we've already allocated the storage and this would
+       // allocate again and copy.  So we do this ugly hack, which is even
+       // even more unsafe than it looks as it depends the memory
+       // representation of a string matching the beginning of the memory
+       // representation of a byte slice (a byte slice is longer).
+       *(*string)(p) = *(*string)(unsafe.Pointer(&b))
 }
 
-func ignoreUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// ignoreUint8Array skips over the data for a byte slice value with no destination.
+func ignoreUint8Array(i *decInstr, state *decoderState, p unsafe.Pointer) {
        b := make([]byte, state.decodeUint())
        state.b.Read(b)
 }
@@ -409,9 +454,16 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr {
        return *(*uintptr)(up)
 }
 
+// decodeSingle decodes a top-level value that is not a struct and stores it through p.
+// Such values are preceded by a zero, making them have the memory layout of a
+// struct field (although with an illegal field number).
 func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) (err os.Error) {
-       p = allocate(ut.base, p, ut.indir)
-       state := newDecodeState(dec, &dec.buf)
+       indir := ut.indir
+       if ut.isGobDecoder {
+               indir = int(ut.decIndir)
+       }
+       p = allocate(ut.base, p, indir)
+       state := dec.newDecoderState(&dec.buf)
        state.fieldnum = singletonField
        basep := p
        delta := int(state.decodeUint())
@@ -424,16 +476,18 @@ func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr)
                ptr = decIndirect(ptr, instr.indir)
        }
        instr.op(instr, state, ptr)
+       dec.freeDecoderState(state)
        return nil
 }
 
+// decodeSingle decodes a top-level struct and stores it through p.
 // Indir is for the value, not the type.  At the time of the call it may
 // differ from ut.indir, which was computed when the engine was built.
 // This state cannot arise for decodeSingle, which is called directly
 // from the user's value, not from the innards of an engine.
-func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, indir int) (err os.Error) {
+func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, indir int) {
        p = allocate(ut.base.(*reflect.StructType), p, indir)
-       state := newDecodeState(dec, &dec.buf)
+       state := dec.newDecoderState(&dec.buf)
        state.fieldnum = -1
        basep := p
        for state.b.Len() > 0 {
@@ -457,11 +511,12 @@ func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr,
                instr.op(instr, state, p)
                state.fieldnum = fieldnum
        }
-       return nil
+       dec.freeDecoderState(state)
 }
 
-func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) {
-       state := newDecodeState(dec, &dec.buf)
+// ignoreStruct discards the data for a struct with no destination.
+func (dec *Decoder) ignoreStruct(engine *decEngine) {
+       state := dec.newDecoderState(&dec.buf)
        state.fieldnum = -1
        for state.b.Len() > 0 {
                delta := int(state.decodeUint())
@@ -479,11 +534,13 @@ func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) {
                instr.op(instr, state, unsafe.Pointer(nil))
                state.fieldnum = fieldnum
        }
-       return nil
+       dec.freeDecoderState(state)
 }
 
-func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) {
-       state := newDecodeState(dec, &dec.buf)
+// ignoreSingle discards the data for a top-level non-struct value with no
+// destination. It's used when calling Decode with a nil value.
+func (dec *Decoder) ignoreSingle(engine *decEngine) {
+       state := dec.newDecoderState(&dec.buf)
        state.fieldnum = singletonField
        delta := int(state.decodeUint())
        if delta != 0 {
@@ -491,10 +548,11 @@ func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) {
        }
        instr := &engine.instr[singletonField]
        instr.op(instr, state, unsafe.Pointer(nil))
-       return nil
+       dec.freeDecoderState(state)
 }
 
-func (dec *Decoder) decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) {
+// decodeArrayHelper does the work for decoding arrays and slices.
+func (dec *Decoder) decodeArrayHelper(state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) {
        instr := &decInstr{elemOp, 0, elemIndir, 0, ovfl}
        for i := 0; i < length; i++ {
                up := unsafe.Pointer(p)
@@ -506,7 +564,10 @@ func (dec *Decoder) decodeArrayHelper(state *decodeState, p uintptr, elemOp decO
        }
 }
 
-func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) {
+// decodeArray decodes an array and stores it through p, that is, p points to the zeroth element.
+// The length is an unsigned integer preceding the elements.  Even though the length is redundant
+// (it's part of the type), it's a useful check and is included in the encoding.
+func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) {
        if indir > 0 {
                p = allocate(atyp, p, 1) // All but the last level has been allocated by dec.Indirect
        }
@@ -516,7 +577,9 @@ func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decodeState, p u
        dec.decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl)
 }
 
-func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value {
+// decodeIntoValue is a helper for map decoding.  Since maps are decoded using reflection,
+// unlike the other items we can't use a pointer directly.
+func decodeIntoValue(state *decoderState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value {
        instr := &decInstr{op, 0, indir, 0, ovfl}
        up := unsafe.Pointer(v.UnsafeAddr())
        if indir > 1 {
@@ -526,7 +589,11 @@ func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, o
        return v
 }
 
-func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) {
+// decodeMap decodes a map and stores its header through p.
+// Maps are encoded as a length followed by key:value pairs.
+// Because the internals of maps are not visible to us, we must
+// use reflection rather than pointer magic.
+func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decoderState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) {
        if indir > 0 {
                p = allocate(mtyp, p, 1) // All but the last level has been allocated by dec.Indirect
        }
@@ -538,7 +605,7 @@ func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintp
        // Maps cannot be accessed by moving addresses around the way
        // that slices etc. can.  We must recover a full reflection value for
        // the iteration.
-       v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer((p)))).(*reflect.MapValue)
+       v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer(p))).(*reflect.MapValue)
        n := int(state.decodeUint())
        for i := 0; i < n; i++ {
                key := decodeIntoValue(state, keyOp, keyIndir, reflect.MakeZero(mtyp.Key()), ovfl)
@@ -547,21 +614,24 @@ func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintp
        }
 }
 
-func (dec *Decoder) ignoreArrayHelper(state *decodeState, elemOp decOp, length int) {
+// ignoreArrayHelper does the work for discarding arrays and slices.
+func (dec *Decoder) ignoreArrayHelper(state *decoderState, elemOp decOp, length int) {
        instr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")}
        for i := 0; i < length; i++ {
                elemOp(instr, state, nil)
        }
 }
 
-func (dec *Decoder) ignoreArray(state *decodeState, elemOp decOp, length int) {
+// ignoreArray discards the data for an array value with no destination.
+func (dec *Decoder) ignoreArray(state *decoderState, elemOp decOp, length int) {
        if n := state.decodeUint(); n != uint64(length) {
                errorf("gob: length mismatch in ignoreArray")
        }
        dec.ignoreArrayHelper(state, elemOp, length)
 }
 
-func (dec *Decoder) ignoreMap(state *decodeState, keyOp, elemOp decOp) {
+// ignoreMap discards the data for a map value with no destination.
+func (dec *Decoder) ignoreMap(state *decoderState, keyOp, elemOp decOp) {
        n := int(state.decodeUint())
        keyInstr := &decInstr{keyOp, 0, 0, 0, os.ErrorString("no error")}
        elemInstr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")}
@@ -571,7 +641,9 @@ func (dec *Decoder) ignoreMap(state *decodeState, keyOp, elemOp decOp) {
        }
 }
 
-func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) {
+// decodeSlice decodes a slice and stores the slice header through p.
+// Slices are encoded as an unsigned length followed by the elements.
+func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) {
        n := int(uintptr(state.decodeUint()))
        if indir > 0 {
                up := unsafe.Pointer(p)
@@ -590,7 +662,8 @@ func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decodeState, p u
        dec.decodeArrayHelper(state, hdrp.Data, elemOp, elemWid, n, elemIndir, ovfl)
 }
 
-func (dec *Decoder) ignoreSlice(state *decodeState, elemOp decOp) {
+// ignoreSlice skips over the data for a slice value with no destination.
+func (dec *Decoder) ignoreSlice(state *decoderState, elemOp decOp) {
        dec.ignoreArrayHelper(state, elemOp, int(state.decodeUint()))
 }
 
@@ -609,9 +682,10 @@ func setInterfaceValue(ivalue *reflect.InterfaceValue, value reflect.Value) {
        ivalue.Set(value)
 }
 
-// decodeInterface receives the name of a concrete type followed by its value.
+// decodeInterface decodes an interface value and stores it through p.
+// Interfaces are encoded as the name of a concrete type followed by a value.
 // If the name is empty, the value is nil and no value is sent.
-func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeState, p uintptr, indir int) {
+func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decoderState, p uintptr, indir int) {
        // Create an interface reflect.Value.  We need one even for the nil case.
        ivalue := reflect.MakeZero(ityp).(*reflect.InterfaceValue)
        // Read the name of the concrete type.
@@ -655,7 +729,8 @@ func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeSt
        *(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.Get()
 }
 
-func (dec *Decoder) ignoreInterface(state *decodeState) {
+// ignoreInterface discards the data for an interface value with no destination.
+func (dec *Decoder) ignoreInterface(state *decoderState) {
        // Read the name of the concrete type.
        b := make([]byte, state.decodeUint())
        _, err := state.b.Read(b)
@@ -670,6 +745,32 @@ func (dec *Decoder) ignoreInterface(state *decodeState) {
        state.b.Next(int(state.decodeUint()))
 }
 
+// decodeGobDecoder decodes something implementing the GobDecoder interface.
+// The data is encoded as a byte slice.
+func (dec *Decoder) decodeGobDecoder(state *decoderState, v reflect.Value, index int) {
+       // Read the bytes for the value.
+       b := make([]byte, state.decodeUint())
+       _, err := state.b.Read(b)
+       if err != nil {
+               error(err)
+       }
+       // We know it's a GobDecoder, so just call the method directly.
+       err = v.Interface().(GobDecoder).GobDecode(b)
+       if err != nil {
+               error(err)
+       }
+}
+
+// ignoreGobDecoder discards the data for a GobDecoder value with no destination.
+func (dec *Decoder) ignoreGobDecoder(state *decoderState) {
+       // Read the bytes for the value.
+       b := make([]byte, state.decodeUint())
+       _, err := state.b.Read(b)
+       if err != nil {
+               error(err)
+       }
+}
+
 // Index by Go types.
 var decOpTable = [...]decOp{
        reflect.Bool:       decBool,
@@ -699,10 +800,14 @@ var decIgnoreOpMap = map[typeId]decOp{
        tComplex: ignoreTwoUints,
 }
 
-// Return the decoding op for the base type under rt and
+// decOpFor returns the decoding op for the base type under rt and
 // the indirection count to reach it.
 func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) (*decOp, int) {
        ut := userType(rt)
+       // If the type implements GobEncoder, we handle it without further processing.
+       if ut.isGobDecoder {
+               return dec.gobDecodeOpFor(ut)
+       }
        // If this type is already in progress, it's a recursive type (e.g. map[string]*T).
        // Return the pointer to the op we're already building.
        if opPtr := inProgress[rt]; opPtr != nil {
@@ -724,7 +829,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
                        elemId := dec.wireType[wireId].ArrayT.Elem
                        elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
                        ovfl := overflow(name)
-                       op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+                       op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
                                state.dec.decodeArray(t, state, uintptr(p), *elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
                        }
 
@@ -735,7 +840,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
                        keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name, inProgress)
                        elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
                        ovfl := overflow(name)
-                       op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+                       op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
                                up := unsafe.Pointer(p)
                                state.dec.decodeMap(t, state, uintptr(up), *keyOp, *elemOp, i.indir, keyIndir, elemIndir, ovfl)
                        }
@@ -754,26 +859,23 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
                        }
                        elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
                        ovfl := overflow(name)
-                       op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+                       op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
                                state.dec.decodeSlice(t, state, uintptr(p), *elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
                        }
 
                case *reflect.StructType:
                        // Generate a closure that calls out to the engine for the nested type.
-                       enginePtr, err := dec.getDecEnginePtr(wireId, typ)
+                       enginePtr, err := dec.getDecEnginePtr(wireId, userType(typ))
                        if err != nil {
                                error(err)
                        }
-                       op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+                       op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
                                // indirect through enginePtr to delay evaluation for recursive structs.
-                               err = dec.decodeStruct(*enginePtr, userType(typ), uintptr(p), i.indir)
-                               if err != nil {
-                                       error(err)
-                               }
+                               dec.decodeStruct(*enginePtr, userType(typ), uintptr(p), i.indir)
                        }
                case *reflect.InterfaceType:
-                       op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
-                               dec.decodeInterface(t, state, uintptr(p), i.indir)
+                       op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
+                               state.dec.decodeInterface(t, state, uintptr(p), i.indir)
                        }
                }
        }
@@ -783,15 +885,15 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
        return &op, indir
 }
 
-// Return the decoding op for a field that has no destination.
+// decIgnoreOpFor returns the decoding op for a field that has no destination.
 func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
        op, ok := decIgnoreOpMap[wireId]
        if !ok {
                if wireId == tInterface {
                        // Special case because it's a method: the ignored item might
                        // define types and we need to record their state in the decoder.
-                       op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
-                               dec.ignoreInterface(state)
+                       op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
+                               state.dec.ignoreInterface(state)
                        }
                        return op
                }
@@ -799,11 +901,11 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
                wire := dec.wireType[wireId]
                switch {
                case wire == nil:
-                       panic("internal error: can't find ignore op for type " + wireId.string())
+                       errorf("gob: bad data: undefined type %s", wireId.string())
                case wire.ArrayT != nil:
                        elemId := wire.ArrayT.Elem
                        elemOp := dec.decIgnoreOpFor(elemId)
-                       op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+                       op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
                                state.dec.ignoreArray(state, elemOp, wire.ArrayT.Len)
                        }
 
@@ -812,14 +914,14 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
                        elemId := dec.wireType[wireId].MapT.Elem
                        keyOp := dec.decIgnoreOpFor(keyId)
                        elemOp := dec.decIgnoreOpFor(elemId)
-                       op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+                       op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
                                state.dec.ignoreMap(state, keyOp, elemOp)
                        }
 
                case wire.SliceT != nil:
                        elemId := wire.SliceT.Elem
                        elemOp := dec.decIgnoreOpFor(elemId)
-                       op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+                       op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
                                state.dec.ignoreSlice(state, elemOp)
                        }
 
@@ -829,28 +931,75 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
                        if err != nil {
                                error(err)
                        }
-                       op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+                       op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
                                // indirect through enginePtr to delay evaluation for recursive structs
                                state.dec.ignoreStruct(*enginePtr)
                        }
+
+               case wire.GobEncoderT != nil:
+                       op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
+                               state.dec.ignoreGobDecoder(state)
+                       }
                }
        }
        if op == nil {
-               errorf("ignore can't handle type %s", wireId.string())
+               errorf("gob: bad data: ignore can't handle type %s", wireId.string())
        }
        return op
 }
 
-// Are these two gob Types compatible?
-// Answers the question for basic types, arrays, and slices.
+// gobDecodeOpFor returns the op for a type that is known to implement
+// GobDecoder.
+func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) (*decOp, int) {
+       rt := ut.user
+       if ut.decIndir == -1 {
+               rt = reflect.PtrTo(rt)
+       } else if ut.decIndir > 0 {
+               for i := int8(0); i < ut.decIndir; i++ {
+                       rt = rt.(*reflect.PtrType).Elem()
+               }
+       }
+       var op decOp
+       op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
+               // Allocate the underlying data, but hold on to the address we have,
+               // since we need it to get to the receiver's address.
+               allocate(ut.base, uintptr(p), ut.indir)
+               var v reflect.Value
+               if ut.decIndir == -1 {
+                       // Need to climb up one level to turn value into pointer.
+                       v = reflect.NewValue(unsafe.Unreflect(rt, unsafe.Pointer(&p)))
+               } else {
+                       if ut.decIndir > 0 {
+                               p = decIndirect(p, int(ut.decIndir))
+                       }
+                       v = reflect.NewValue(unsafe.Unreflect(rt, p))
+               }
+               state.dec.decodeGobDecoder(state, v, methodIndex(rt, gobDecodeMethodName))
+       }
+       return &op, int(ut.decIndir)
+
+}
+
+// compatibleType asks: Are these two gob Types compatible?
+// Answers the question for basic types, arrays, maps and slices, plus
+// GobEncoder/Decoder pairs.
 // Structs are considered ok; fields will be checked later.
 func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[reflect.Type]typeId) bool {
        if rhs, ok := inProgress[fr]; ok {
                return rhs == fw
        }
        inProgress[fr] = fw
-       fr = userType(fr).base
-       switch t := fr.(type) {
+       ut := userType(fr)
+       wire, ok := dec.wireType[fw]
+       // If fr is a GobDecoder, the wire type must be GobEncoder.
+       // And if fr is not a GobDecoder, the wire type must not be either.
+       if ut.isGobDecoder != (ok && wire.GobEncoderT != nil) { // the parentheses look odd but are correct.
+               return false
+       }
+       if ut.isGobDecoder { // This test trumps all others.
+               return true
+       }
+       switch t := ut.base.(type) {
        default:
                // chan, etc: cannot handle.
                return false
@@ -869,14 +1018,12 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[re
        case *reflect.InterfaceType:
                return fw == tInterface
        case *reflect.ArrayType:
-               wire, ok := dec.wireType[fw]
                if !ok || wire.ArrayT == nil {
                        return false
                }
                array := wire.ArrayT
                return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem, inProgress)
        case *reflect.MapType:
-               wire, ok := dec.wireType[fw]
                if !ok || wire.MapT == nil {
                        return false
                }
@@ -911,8 +1058,13 @@ func (dec *Decoder) typeString(remoteId typeId) string {
        return dec.wireType[remoteId].string()
 }
 
-
-func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
+// compileSingle compiles the decoder engine for a non-struct top-level value, including
+// GobDecoders.
+func (dec *Decoder) compileSingle(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err os.Error) {
+       rt := ut.base
+       if ut.isGobDecoder {
+               rt = ut.user
+       }
        engine = new(decEngine)
        engine.instr = make([]decInstr, 1) // one item
        name := rt.String()                // best we can do
@@ -926,6 +1078,7 @@ func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *dec
        return
 }
 
+// compileIgnoreSingle compiles the decoder engine for a non-struct top-level value that will be discarded.
 func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err os.Error) {
        engine = new(decEngine)
        engine.instr = make([]decInstr, 1) // one item
@@ -936,16 +1089,13 @@ func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err
        return
 }
 
-// Is this an exported - upper case - name?
-func isExported(name string) bool {
-       rune, _ := utf8.DecodeRuneInString(name)
-       return unicode.IsUpper(rune)
-}
-
-func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
+// compileDec compiles the decoder engine for a value.  If the value is not a struct,
+// it calls out to compileSingle.
+func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err os.Error) {
+       rt := ut.base
        srt, ok := rt.(*reflect.StructType)
-       if !ok {
-               return dec.compileSingle(remoteId, rt)
+       if !ok || ut.isGobDecoder {
+               return dec.compileSingle(remoteId, ut)
        }
        var wireStruct *structType
        // Builtin types can come from global pool; the rest must be defined by the decoder.
@@ -990,7 +1140,9 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
        return
 }
 
-func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr **decEngine, err os.Error) {
+// getDecEnginePtr returns the engine for the specified type.
+func (dec *Decoder) getDecEnginePtr(remoteId typeId, ut *userTypeInfo) (enginePtr **decEngine, err os.Error) {
+       rt := ut.base
        decoderMap, ok := dec.decoderCache[rt]
        if !ok {
                decoderMap = make(map[typeId]**decEngine)
@@ -1000,7 +1152,7 @@ func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr
                // To handle recursive types, mark this engine as underway before compiling.
                enginePtr = new(*decEngine)
                decoderMap[remoteId] = enginePtr
-               *enginePtr, err = dec.compileDec(remoteId, rt)
+               *enginePtr, err = dec.compileDec(remoteId, ut)
                if err != nil {
                        decoderMap[remoteId] = nil, false
                }
@@ -1008,11 +1160,12 @@ func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr
        return
 }
 
-// When ignoring struct data, in effect we compile it into this type
+// emptyStruct is the type we compile into when ignoring a struct value.
 type emptyStruct struct{}
 
 var emptyStructType = reflect.Typeof(emptyStruct{})
 
+// getDecEnginePtr returns the engine for the specified type when the value is to be discarded.
 func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) {
        var ok bool
        if enginePtr, ok = dec.ignorerCache[wireId]; !ok {
@@ -1021,7 +1174,7 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er
                dec.ignorerCache[wireId] = enginePtr
                wire := dec.wireType[wireId]
                if wire != nil && wire.StructT != nil {
-                       *enginePtr, err = dec.compileDec(wireId, emptyStructType)
+                       *enginePtr, err = dec.compileDec(wireId, userType(emptyStructType))
                } else {
                        *enginePtr, err = dec.compileIgnoreSingle(wireId)
                }
@@ -1032,41 +1185,51 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er
        return
 }
 
-func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) {
-       defer catchError(&err)
+// decodeValue decodes the data stream representing a value and stores it in val.
+func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) {
+       defer catchError(&dec.err)
        // If the value is nil, it means we should just ignore this item.
        if val == nil {
-               return dec.decodeIgnoredValue(wireId)
+               dec.decodeIgnoredValue(wireId)
+               return
        }
        // Dereference down to the underlying struct type.
        ut := userType(val.Type())
        base := ut.base
        indir := ut.indir
-       enginePtr, err := dec.getDecEnginePtr(wireId, base)
-       if err != nil {
-               return err
+       if ut.isGobDecoder {
+               indir = int(ut.decIndir)
+       }
+       var enginePtr **decEngine
+       enginePtr, dec.err = dec.getDecEnginePtr(wireId, ut)
+       if dec.err != nil {
+               return
        }
        engine := *enginePtr
-       if st, ok := base.(*reflect.StructType); ok {
+       if st, ok := base.(*reflect.StructType); ok && !ut.isGobDecoder {
                if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].StructT.Field) > 0 {
                        name := base.Name()
-                       return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name)
+                       errorf("gob: type mismatch: no fields matched compiling decoder for %s", name)
                }
-               return dec.decodeStruct(engine, ut, uintptr(val.UnsafeAddr()), indir)
+               dec.decodeStruct(engine, ut, uintptr(val.UnsafeAddr()), indir)
+       } else {
+               dec.decodeSingle(engine, ut, uintptr(val.UnsafeAddr()))
        }
-       return dec.decodeSingle(engine, ut, uintptr(val.UnsafeAddr()))
 }
 
-func (dec *Decoder) decodeIgnoredValue(wireId typeId) os.Error {
-       enginePtr, err := dec.getIgnoreEnginePtr(wireId)
-       if err != nil {
-               return err
+// decodeIgnoredValue decodes the data stream representing a value of the specified type and discards it.
+func (dec *Decoder) decodeIgnoredValue(wireId typeId) {
+       var enginePtr **decEngine
+       enginePtr, dec.err = dec.getIgnoreEnginePtr(wireId)
+       if dec.err != nil {
+               return
        }
        wire := dec.wireType[wireId]
        if wire != nil && wire.StructT != nil {
-               return dec.ignoreStruct(*enginePtr)
+               dec.ignoreStruct(*enginePtr)
+       } else {
+               dec.ignoreSingle(*enginePtr)
        }
-       return dec.ignoreSingle(*enginePtr)
 }
 
 func init() {
index f7c994ffa7844ec5d3e99be41deffce168d5647a..34364161aa3bf50ef5caa723aee93ae2df0d7066 100644 (file)
@@ -5,6 +5,7 @@
 package gob
 
 import (
+       "bufio"
        "bytes"
        "io"
        "os"
@@ -21,7 +22,7 @@ type Decoder struct {
        wireType     map[typeId]*wireType                    // map from remote ID to local description
        decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines
        ignorerCache map[typeId]**decEngine                  // ditto for ignored objects
-       countState   *decodeState                            // reads counts from wire
+       freeList     *decoderState                           // list of free decoderStates; avoids reallocation
        countBuf     []byte                                  // used for decoding integers while parsing messages
        tmp          []byte                                  // temporary storage for i/o; saves reallocating
        err          os.Error
@@ -30,7 +31,7 @@ type Decoder struct {
 // NewDecoder returns a new decoder that reads from the io.Reader.
 func NewDecoder(r io.Reader) *Decoder {
        dec := new(Decoder)
-       dec.r = r
+       dec.r = bufio.NewReader(r)
        dec.wireType = make(map[typeId]*wireType)
        dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine)
        dec.ignorerCache = make(map[typeId]**decEngine)
@@ -49,7 +50,7 @@ func (dec *Decoder) recvType(id typeId) {
 
        // Type:
        wire := new(wireType)
-       dec.err = dec.decodeValue(tWireType, reflect.NewValue(wire))
+       dec.decodeValue(tWireType, reflect.NewValue(wire))
        if dec.err != nil {
                return
        }
@@ -184,7 +185,7 @@ func (dec *Decoder) DecodeValue(value reflect.Value) os.Error {
        dec.err = nil
        id := dec.decodeTypeSequence(false)
        if dec.err == nil {
-               dec.err = dec.decodeValue(id, value)
+               dec.decodeValue(id, value)
        }
        return dec.err
 }
index e92db74ffdda3e0b743f109b6f9b7eff03511d38..5cfdb583a1804e57848d5aa642458efffab3eba8 100644 (file)
@@ -6,16 +6,14 @@ package gob
 
 import (
        "bytes"
-       "io"
        "math"
-       "os"
        "reflect"
        "unsafe"
 )
 
 const uint64Size = unsafe.Sizeof(uint64(0))
 
-// The global execution state of an instance of the encoder.
+// encoderState is the global execution state of an instance of the encoder.
 // Field numbers are delta encoded and always increase. The field
 // number is initialized to -1 so 0 comes out as delta(1). A delta of
 // 0 terminates the structure.
@@ -25,10 +23,26 @@ type encoderState struct {
        sendZero bool                 // encoding an array element or map key/value pair; send zero values
        fieldnum int                  // the last field number written.
        buf      [1 + uint64Size]byte // buffer used by the encoder; here to avoid allocation.
+       next     *encoderState        // for free list
 }
 
-func newEncoderState(enc *Encoder, b *bytes.Buffer) *encoderState {
-       return &encoderState{enc: enc, b: b}
+func (enc *Encoder) newEncoderState(b *bytes.Buffer) *encoderState {
+       e := enc.freeList
+       if e == nil {
+               e = new(encoderState)
+               e.enc = enc
+       } else {
+               enc.freeList = e.next
+       }
+       e.sendZero = false
+       e.fieldnum = 0
+       e.b = b
+       return e
+}
+
+func (enc *Encoder) freeEncoderState(e *encoderState) {
+       e.next = enc.freeList
+       enc.freeList = e
 }
 
 // Unsigned integers have a two-state encoding.  If the number is less
@@ -72,6 +86,7 @@ func (state *encoderState) encodeInt(i int64) {
        state.encodeUint(uint64(x))
 }
 
+// encOp is the signature of an encoding operator for a given type.
 type encOp func(i *encInstr, state *encoderState, p unsafe.Pointer)
 
 // The 'instructions' of the encoding machine
@@ -82,8 +97,8 @@ type encInstr struct {
        offset uintptr // offset in the structure of the field to encode
 }
 
-// Emit a field number and update the state to record its value for delta encoding.
-// If the instruction pointer is nil, do nothing
+// update emits a field number and updates the state to record its value for delta encoding.
+// If the instruction pointer is nil, it does nothing
 func (state *encoderState) update(instr *encInstr) {
        if instr != nil {
                state.encodeUint(uint64(instr.field - state.fieldnum))
@@ -91,12 +106,16 @@ func (state *encoderState) update(instr *encInstr) {
        }
 }
 
-// Each encoder is responsible for handling any indirections associated
-// with the data structure.  If any pointer so reached is nil, no bytes are written.
-// If the data item is zero, no bytes are written.
-// Otherwise, the output (for a scalar) is the field number, as an encoded integer,
-// followed by the field data in its appropriate format.
+// Each encoder for a composite is responsible for handling any
+// indirections associated with the elements of the data structure.
+// If any pointer so reached is nil, no bytes are written.  If the
+// data item is zero, no bytes are written.  Single values - ints,
+// strings etc. - are indirected before calling their encoders.
+// Otherwise, the output (for a scalar) is the field number, as an
+// encoded integer, followed by the field data in its appropriate
+// format.
 
+// encIndirect dereferences p indir times and returns the result.
 func encIndirect(p unsafe.Pointer, indir int) unsafe.Pointer {
        for ; indir > 0; indir-- {
                p = *(*unsafe.Pointer)(p)
@@ -107,6 +126,7 @@ func encIndirect(p unsafe.Pointer, indir int) unsafe.Pointer {
        return p
 }
 
+// encBool encodes the bool with address p as an unsigned 0 or 1.
 func encBool(i *encInstr, state *encoderState, p unsafe.Pointer) {
        b := *(*bool)(p)
        if b || state.sendZero {
@@ -119,6 +139,7 @@ func encBool(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// encInt encodes the int with address p.
 func encInt(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := int64(*(*int)(p))
        if v != 0 || state.sendZero {
@@ -127,6 +148,7 @@ func encInt(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// encUint encodes the uint with address p.
 func encUint(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := uint64(*(*uint)(p))
        if v != 0 || state.sendZero {
@@ -135,6 +157,7 @@ func encUint(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// encInt8 encodes the int8 with address p.
 func encInt8(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := int64(*(*int8)(p))
        if v != 0 || state.sendZero {
@@ -143,6 +166,7 @@ func encInt8(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// encUint8 encodes the uint8 with address p.
 func encUint8(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := uint64(*(*uint8)(p))
        if v != 0 || state.sendZero {
@@ -151,6 +175,7 @@ func encUint8(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// encInt16 encodes the int16 with address p.
 func encInt16(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := int64(*(*int16)(p))
        if v != 0 || state.sendZero {
@@ -159,6 +184,7 @@ func encInt16(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// encUint16 encodes the uint16 with address p.
 func encUint16(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := uint64(*(*uint16)(p))
        if v != 0 || state.sendZero {
@@ -167,6 +193,7 @@ func encUint16(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// encInt32 encodes the int32 with address p.
 func encInt32(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := int64(*(*int32)(p))
        if v != 0 || state.sendZero {
@@ -175,6 +202,7 @@ func encInt32(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// encUint encodes the uint32 with address p.
 func encUint32(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := uint64(*(*uint32)(p))
        if v != 0 || state.sendZero {
@@ -183,6 +211,7 @@ func encUint32(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// encInt64 encodes the int64 with address p.
 func encInt64(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := *(*int64)(p)
        if v != 0 || state.sendZero {
@@ -191,6 +220,7 @@ func encInt64(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// encInt64 encodes the uint64 with address p.
 func encUint64(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := *(*uint64)(p)
        if v != 0 || state.sendZero {
@@ -199,6 +229,7 @@ func encUint64(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// encUintptr encodes the uintptr with address p.
 func encUintptr(i *encInstr, state *encoderState, p unsafe.Pointer) {
        v := uint64(*(*uintptr)(p))
        if v != 0 || state.sendZero {
@@ -207,6 +238,7 @@ func encUintptr(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// floatBits returns a uint64 holding the bits of a floating-point number.
 // Floating-point numbers are transmitted as uint64s holding the bits
 // of the underlying representation.  They are sent byte-reversed, with
 // the exponent end coming out first, so integer floating point numbers
@@ -223,6 +255,7 @@ func floatBits(f float64) uint64 {
        return v
 }
 
+// encFloat32 encodes the float32 with address p.
 func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) {
        f := *(*float32)(p)
        if f != 0 || state.sendZero {
@@ -232,6 +265,7 @@ func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// encFloat64 encodes the float64 with address p.
 func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) {
        f := *(*float64)(p)
        if f != 0 || state.sendZero {
@@ -241,6 +275,7 @@ func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// encComplex64 encodes the complex64 with address p.
 // Complex numbers are just a pair of floating-point numbers, real part first.
 func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) {
        c := *(*complex64)(p)
@@ -253,6 +288,7 @@ func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// encComplex128 encodes the complex128 with address p.
 func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) {
        c := *(*complex128)(p)
        if c != 0+0i || state.sendZero {
@@ -264,6 +300,7 @@ func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// encUint8Array encodes the byte slice whose header has address p.
 // Byte arrays are encoded as an unsigned count followed by the raw bytes.
 func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) {
        b := *(*[]byte)(p)
@@ -274,24 +311,26 @@ func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// encString encodes the string whose header has address p.
 // Strings are encoded as an unsigned count followed by the raw bytes.
 func encString(i *encInstr, state *encoderState, p unsafe.Pointer) {
        s := *(*string)(p)
        if len(s) > 0 || state.sendZero {
                state.update(i)
                state.encodeUint(uint64(len(s)))
-               io.WriteString(state.b, s)
+               state.b.WriteString(s)
        }
 }
 
-// The end of a struct is marked by a delta field number of 0.
+// encStructTerminator encodes the end of an encoded struct
+// as delta field number of 0.
 func encStructTerminator(i *encInstr, state *encoderState, p unsafe.Pointer) {
        state.encodeUint(0)
 }
 
 // Execution engine
 
-// The encoder engine is an array of instructions indexed by field number of the encoding
+// encEngine an array of instructions indexed by field number of the encoding
 // data, typically a struct.  It is executed top to bottom, walking the struct.
 type encEngine struct {
        instr []encInstr
@@ -299,8 +338,9 @@ type encEngine struct {
 
 const singletonField = 0
 
+// encodeSingle encodes a single top-level non-struct value.
 func (enc *Encoder) encodeSingle(b *bytes.Buffer, engine *encEngine, basep uintptr) {
-       state := newEncoderState(enc, b)
+       state := enc.newEncoderState(b)
        state.fieldnum = singletonField
        // There is no surrounding struct to frame the transmission, so we must
        // generate data even if the item is zero.  To do this, set sendZero.
@@ -313,10 +353,12 @@ func (enc *Encoder) encodeSingle(b *bytes.Buffer, engine *encEngine, basep uintp
                }
        }
        instr.op(instr, state, p)
+       enc.freeEncoderState(state)
 }
 
+// encodeStruct encodes a single struct value.
 func (enc *Encoder) encodeStruct(b *bytes.Buffer, engine *encEngine, basep uintptr) {
-       state := newEncoderState(enc, b)
+       state := enc.newEncoderState(b)
        state.fieldnum = -1
        for i := 0; i < len(engine.instr); i++ {
                instr := &engine.instr[i]
@@ -328,10 +370,12 @@ func (enc *Encoder) encodeStruct(b *bytes.Buffer, engine *encEngine, basep uintp
                }
                instr.op(instr, state, p)
        }
+       enc.freeEncoderState(state)
 }
 
+// encodeArray encodes the array whose 0th element is at p.
 func (enc *Encoder) encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndir int, length int) {
-       state := newEncoderState(enc, b)
+       state := enc.newEncoderState(b)
        state.fieldnum = -1
        state.sendZero = true
        state.encodeUint(uint64(length))
@@ -347,8 +391,10 @@ func (enc *Encoder) encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid ui
                op(nil, state, unsafe.Pointer(elemp))
                p += uintptr(elemWid)
        }
+       enc.freeEncoderState(state)
 }
 
+// encodeReflectValue is a helper for maps. It encodes the value v.
 func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir int) {
        for i := 0; i < indir && v != nil; i++ {
                v = reflect.Indirect(v)
@@ -359,8 +405,11 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in
        op(nil, state, unsafe.Pointer(v.UnsafeAddr()))
 }
 
+// encodeMap encodes a map as unsigned count followed by key:value pairs.
+// Because map internals are not exposed, we must use reflection rather than
+// addresses.
 func (enc *Encoder) encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elemOp encOp, keyIndir, elemIndir int) {
-       state := newEncoderState(enc, b)
+       state := enc.newEncoderState(b)
        state.fieldnum = -1
        state.sendZero = true
        keys := mv.Keys()
@@ -369,14 +418,16 @@ func (enc *Encoder) encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elem
                encodeReflectValue(state, key, keyOp, keyIndir)
                encodeReflectValue(state, mv.Elem(key), elemOp, elemIndir)
        }
+       enc.freeEncoderState(state)
 }
 
+// encodeInterface encodes the interface value iv.
 // To send an interface, we send a string identifying the concrete type, followed
 // by the type identifier (which might require defining that type right now), followed
 // by the concrete value.  A nil value gets sent as the empty string for the name,
 // followed by no value.
 func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) {
-       state := newEncoderState(enc, b)
+       state := enc.newEncoderState(b)
        state.fieldnum = -1
        state.sendZero = true
        if iv.IsNil() {
@@ -391,7 +442,7 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue)
        }
        // Send the name.
        state.encodeUint(uint64(len(name)))
-       _, err := io.WriteString(state.b, name)
+       _, err := state.b.WriteString(name)
        if err != nil {
                error(err)
        }
@@ -403,15 +454,32 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue)
        // should be written to b, before the encoded value.
        enc.pushWriter(b)
        data := new(bytes.Buffer)
-       err = enc.encode(data, iv.Elem(), ut)
-       if err != nil {
-               error(err)
+       enc.encode(data, iv.Elem(), ut)
+       if enc.err != nil {
+               error(enc.err)
        }
        enc.popWriter()
        enc.writeMessage(b, data)
        if enc.err != nil {
                error(err)
        }
+       enc.freeEncoderState(state)
+}
+
+// encGobEncoder encodes a value that implements the GobEncoder interface.
+// The data is sent as a byte array.
+func (enc *Encoder) encodeGobEncoder(b *bytes.Buffer, v reflect.Value, index int) {
+       // TODO: should we catch panics from the called method?
+       // We know it's a GobEncoder, so just call the method directly.
+       data, err := v.Interface().(GobEncoder).GobEncode()
+       if err != nil {
+               error(err)
+       }
+       state := enc.newEncoderState(b)
+       state.fieldnum = -1
+       state.encodeUint(uint64(len(data)))
+       state.b.Write(data)
+       enc.freeEncoderState(state)
 }
 
 var encOpTable = [...]encOp{
@@ -434,10 +502,14 @@ var encOpTable = [...]encOp{
        reflect.String:     encString,
 }
 
-// Return (a pointer to) the encoding op for the base type under rt and
+// encOpFor returns (a pointer to) the encoding op for the base type under rt and
 // the indirection count to reach it.
 func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp) (*encOp, int) {
        ut := userType(rt)
+       // If the type implements GobEncoder, we handle it without further processing.
+       if ut.isGobEncoder {
+               return enc.gobEncodeOpFor(ut)
+       }
        // If this type is already in progress, it's a recursive type (e.g. map[string]*T).
        // Return the pointer to the op we're already building.
        if opPtr := inProgress[rt]; opPtr != nil {
@@ -483,7 +555,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp
                                // Maps cannot be accessed by moving addresses around the way
                                // that slices etc. can.  We must recover a full reflection value for
                                // the iteration.
-                               v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer((p))))
+                               v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer(p)))
                                mv := reflect.Indirect(v).(*reflect.MapValue)
                                if !state.sendZero && mv.Len() == 0 {
                                        return
@@ -493,7 +565,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp
                        }
                case *reflect.StructType:
                        // Generate a closure that calls out to the engine for the nested type.
-                       enc.getEncEngine(typ)
+                       enc.getEncEngine(userType(typ))
                        info := mustGetTypeInfo(typ)
                        op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
                                state.update(i)
@@ -504,7 +576,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp
                        op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
                                // Interfaces transmit the name and contents of the concrete
                                // value they contain.
-                               v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer((p))))
+                               v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer(p)))
                                iv := reflect.Indirect(v).(*reflect.InterfaceValue)
                                if !state.sendZero && (iv == nil || iv.IsNil()) {
                                        return
@@ -520,22 +592,64 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp
        return &op, indir
 }
 
-// The local Type was compiled from the actual value, so we know it's compatible.
-func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine {
-       srt, isStruct := rt.(*reflect.StructType)
+// methodIndex returns which method of rt implements the method.
+func methodIndex(rt reflect.Type, method string) int {
+       for i := 0; i < rt.NumMethod(); i++ {
+               if rt.Method(i).Name == method {
+                       return i
+               }
+       }
+       errorf("gob: internal error: can't find method %s", method)
+       return 0
+}
+
+// gobEncodeOpFor returns the op for a type that is known to implement
+// GobEncoder.
+func (enc *Encoder) gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) {
+       rt := ut.user
+       if ut.encIndir == -1 {
+               rt = reflect.PtrTo(rt)
+       } else if ut.encIndir > 0 {
+               for i := int8(0); i < ut.encIndir; i++ {
+                       rt = rt.(*reflect.PtrType).Elem()
+               }
+       }
+       var op encOp
+       op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
+               var v reflect.Value
+               if ut.encIndir == -1 {
+                       // Need to climb up one level to turn value into pointer.
+                       v = reflect.NewValue(unsafe.Unreflect(rt, unsafe.Pointer(&p)))
+               } else {
+                       v = reflect.NewValue(unsafe.Unreflect(rt, p))
+               }
+               state.update(i)
+               state.enc.encodeGobEncoder(state.b, v, methodIndex(rt, gobEncodeMethodName))
+       }
+       return &op, int(ut.encIndir) // encIndir: op will get called with p == address of receiver.
+}
+
+// compileEnc returns the engine to compile the type.
+func (enc *Encoder) compileEnc(ut *userTypeInfo) *encEngine {
+       srt, isStruct := ut.base.(*reflect.StructType)
        engine := new(encEngine)
        seen := make(map[reflect.Type]*encOp)
-       if isStruct {
-               for fieldNum := 0; fieldNum < srt.NumField(); fieldNum++ {
+       rt := ut.base
+       if ut.isGobEncoder {
+               rt = ut.user
+       }
+       if !ut.isGobEncoder && isStruct {
+               for fieldNum, wireFieldNum := 0, 0; fieldNum < srt.NumField(); fieldNum++ {
                        f := srt.Field(fieldNum)
                        if !isExported(f.Name) {
                                continue
                        }
                        op, indir := enc.encOpFor(f.Type, seen)
-                       engine.instr = append(engine.instr, encInstr{*op, fieldNum, indir, uintptr(f.Offset)})
+                       engine.instr = append(engine.instr, encInstr{*op, wireFieldNum, indir, uintptr(f.Offset)})
+                       wireFieldNum++
                }
                if srt.NumField() > 0 && len(engine.instr) == 0 {
-                       errorf("type %s has no exported fields", rt)
+                       errorf("gob: type %s has no exported fields", rt)
                }
                engine.instr = append(engine.instr, encInstr{encStructTerminator, 0, 0, 0})
        } else {
@@ -546,38 +660,42 @@ func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine {
        return engine
 }
 
+// getEncEngine returns the engine to compile the type.
 // typeLock must be held (or we're in initialization and guaranteed single-threaded).
-// The reflection type must have all its indirections processed out.
-func (enc *Encoder) getEncEngine(rt reflect.Type) *encEngine {
-       info, err1 := getTypeInfo(rt)
+func (enc *Encoder) getEncEngine(ut *userTypeInfo) *encEngine {
+       info, err1 := getTypeInfo(ut)
        if err1 != nil {
                error(err1)
        }
        if info.encoder == nil {
                // mark this engine as underway before compiling to handle recursive types.
                info.encoder = new(encEngine)
-               info.encoder = enc.compileEnc(rt)
+               info.encoder = enc.compileEnc(ut)
        }
        return info.encoder
 }
 
-// Put this in a function so we can hold the lock only while compiling, not when encoding.
-func (enc *Encoder) lockAndGetEncEngine(rt reflect.Type) *encEngine {
+// lockAndGetEncEngine is a function that locks and compiles.
+// This lets us hold the lock only while compiling, not when encoding.
+func (enc *Encoder) lockAndGetEncEngine(ut *userTypeInfo) *encEngine {
        typeLock.Lock()
        defer typeLock.Unlock()
-       return enc.getEncEngine(rt)
+       return enc.getEncEngine(ut)
 }
 
-func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInfo) (err os.Error) {
-       defer catchError(&err)
-       for i := 0; i < ut.indir; i++ {
+func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInfo) {
+       defer catchError(&enc.err)
+       engine := enc.lockAndGetEncEngine(ut)
+       indir := ut.indir
+       if ut.isGobEncoder {
+               indir = int(ut.encIndir)
+       }
+       for i := 0; i < indir; i++ {
                value = reflect.Indirect(value)
        }
-       engine := enc.lockAndGetEncEngine(ut.base)
-       if value.Type().Kind() == reflect.Struct {
+       if !ut.isGobEncoder && value.Type().Kind() == reflect.Struct {
                enc.encodeStruct(b, engine, value.UnsafeAddr())
        } else {
                enc.encodeSingle(b, engine, value.UnsafeAddr())
        }
-       return nil
 }
index 92d036c11c3578846dbbaaf6e637d890ee3676ee..e52a4de29f7117c9f9b3d274a96835daaeb0de38 100644 (file)
@@ -19,7 +19,9 @@ type Encoder struct {
        w          []io.Writer             // where to send the data
        sent       map[reflect.Type]typeId // which types we've already sent
        countState *encoderState           // stage for writing counts
+       freeList   *encoderState           // list of free encoderStates; avoids reallocation
        buf        []byte                  // for collecting the output.
+       byteBuf    bytes.Buffer            // buffer for top-level encoderState
        err        os.Error
 }
 
@@ -28,7 +30,7 @@ func NewEncoder(w io.Writer) *Encoder {
        enc := new(Encoder)
        enc.w = []io.Writer{w}
        enc.sent = make(map[reflect.Type]typeId)
-       enc.countState = newEncoderState(enc, new(bytes.Buffer))
+       enc.countState = enc.newEncoderState(new(bytes.Buffer))
        return enc
 }
 
@@ -78,12 +80,57 @@ func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) {
        }
 }
 
+// sendActualType sends the requested type, without further investigation, unless
+// it's been sent before.
+func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTypeInfo, actual reflect.Type) (sent bool) {
+       if _, alreadySent := enc.sent[actual]; alreadySent {
+               return false
+       }
+       typeLock.Lock()
+       info, err := getTypeInfo(ut)
+       typeLock.Unlock()
+       if err != nil {
+               enc.setError(err)
+               return
+       }
+       // Send the pair (-id, type)
+       // Id:
+       state.encodeInt(-int64(info.id))
+       // Type:
+       enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo)
+       enc.writeMessage(w, state.b)
+       if enc.err != nil {
+               return
+       }
+
+       // Remember we've sent this type, both what the user gave us and the base type.
+       enc.sent[ut.base] = info.id
+       if ut.user != ut.base {
+               enc.sent[ut.user] = info.id
+       }
+       // Now send the inner types
+       switch st := actual.(type) {
+       case *reflect.StructType:
+               for i := 0; i < st.NumField(); i++ {
+                       enc.sendType(w, state, st.Field(i).Type)
+               }
+       case reflect.ArrayOrSliceType:
+               enc.sendType(w, state, st.Elem())
+       }
+       return true
+}
+
+// sendType sends the type info to the other side, if necessary. 
 func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) {
-       // Drill down to the base type.
        ut := userType(origt)
-       rt := ut.base
+       if ut.isGobEncoder {
+               // The rules are different: regardless of the underlying type's representation,
+               // we need to tell the other side that this exact type is a GobEncoder.
+               return enc.sendActualType(w, state, ut, ut.user)
+       }
 
-       switch rt := rt.(type) {
+       // It's a concrete value, so drill down to the base type.
+       switch rt := ut.base.(type) {
        default:
                // Basic types and interfaces do not need to be described.
                return
@@ -109,43 +156,7 @@ func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Typ
                return
        }
 
-       // Have we already sent this type?  This time we ask about the base type.
-       if _, alreadySent := enc.sent[rt]; alreadySent {
-               return
-       }
-
-       // Need to send it.
-       typeLock.Lock()
-       info, err := getTypeInfo(rt)
-       typeLock.Unlock()
-       if err != nil {
-               enc.setError(err)
-               return
-       }
-       // Send the pair (-id, type)
-       // Id:
-       state.encodeInt(-int64(info.id))
-       // Type:
-       enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo)
-       enc.writeMessage(w, state.b)
-       if enc.err != nil {
-               return
-       }
-
-       // Remember we've sent this type.
-       enc.sent[rt] = info.id
-       // Remember we've sent the top-level, possibly indirect type too.
-       enc.sent[origt] = info.id
-       // Now send the inner types
-       switch st := rt.(type) {
-       case *reflect.StructType:
-               for i := 0; i < st.NumField(); i++ {
-                       enc.sendType(w, state, st.Field(i).Type)
-               }
-       case reflect.ArrayOrSliceType:
-               enc.sendType(w, state, st.Elem())
-       }
-       return true
+       return enc.sendActualType(w, state, ut, ut.base)
 }
 
 // Encode transmits the data item represented by the empty interface value,
@@ -159,11 +170,14 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
 // sent.
 func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) {
        // Make sure the type is known to the other side.
-       // First, have we already sent this (base) type?
-       base := ut.base
-       if _, alreadySent := enc.sent[base]; !alreadySent {
+       // First, have we already sent this type?
+       rt := ut.base
+       if ut.isGobEncoder {
+               rt = ut.user
+       }
+       if _, alreadySent := enc.sent[rt]; !alreadySent {
                // No, so send it.
-               sent := enc.sendType(w, state, base)
+               sent := enc.sendType(w, state, rt)
                if enc.err != nil {
                        return
                }
@@ -172,13 +186,13 @@ func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *use
                // need to send the type info but we do need to update enc.sent.
                if !sent {
                        typeLock.Lock()
-                       info, err := getTypeInfo(base)
+                       info, err := getTypeInfo(ut)
                        typeLock.Unlock()
                        if err != nil {
                                enc.setError(err)
                                return
                        }
-                       enc.sent[base] = info.id
+                       enc.sent[rt] = info.id
                }
        }
 }
@@ -206,7 +220,8 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
        }
 
        enc.err = nil
-       state := newEncoderState(enc, new(bytes.Buffer))
+       enc.byteBuf.Reset()
+       state := enc.newEncoderState(&enc.byteBuf)
 
        enc.sendTypeDescriptor(enc.writer(), state, ut)
        enc.sendTypeId(state, ut)
@@ -215,12 +230,11 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
        }
 
        // Encode the object.
-       err = enc.encode(state.b, value, ut)
-       if err != nil {
-               enc.setError(err)
-       } else {
+       enc.encode(state.b, value, ut)
+       if enc.err == nil {
                enc.writeMessage(enc.writer(), state.b)
        }
 
+       enc.freeEncoderState(state)
        return enc.err
 }
diff --git a/libgo/go/gob/gobencdec_test.go b/libgo/go/gob/gobencdec_test.go
new file mode 100644 (file)
index 0000000..012b099
--- /dev/null
@@ -0,0 +1,384 @@
+// Copyright 20011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file contains tests of the GobEncoder/GobDecoder support.
+
+package gob
+
+import (
+       "bytes"
+       "fmt"
+       "os"
+       "strings"
+       "testing"
+)
+
+// Types that implement the GobEncoder/Decoder interfaces.
+
+type ByteStruct struct {
+       a byte // not an exported field
+}
+
+type StringStruct struct {
+       s string // not an exported field
+}
+
+type Gobber int
+
+type ValueGobber string // encodes with a value, decodes with a pointer.
+
+// The relevant methods
+
+func (g *ByteStruct) GobEncode() ([]byte, os.Error) {
+       b := make([]byte, 3)
+       b[0] = g.a
+       b[1] = g.a + 1
+       b[2] = g.a + 2
+       return b, nil
+}
+
+func (g *ByteStruct) GobDecode(data []byte) os.Error {
+       if g == nil {
+               return os.ErrorString("NIL RECEIVER")
+       }
+       // Expect N sequential-valued bytes.
+       if len(data) == 0 {
+               return os.EOF
+       }
+       g.a = data[0]
+       for i, c := range data {
+               if c != g.a+byte(i) {
+                       return os.ErrorString("invalid data sequence")
+               }
+       }
+       return nil
+}
+
+func (g *StringStruct) GobEncode() ([]byte, os.Error) {
+       return []byte(g.s), nil
+}
+
+func (g *StringStruct) GobDecode(data []byte) os.Error {
+       // Expect N sequential-valued bytes.
+       if len(data) == 0 {
+               return os.EOF
+       }
+       a := data[0]
+       for i, c := range data {
+               if c != a+byte(i) {
+                       return os.ErrorString("invalid data sequence")
+               }
+       }
+       g.s = string(data)
+       return nil
+}
+
+func (g *Gobber) GobEncode() ([]byte, os.Error) {
+       return []byte(fmt.Sprintf("VALUE=%d", *g)), nil
+}
+
+func (g *Gobber) GobDecode(data []byte) os.Error {
+       _, err := fmt.Sscanf(string(data), "VALUE=%d", (*int)(g))
+       return err
+}
+
+func (v ValueGobber) GobEncode() ([]byte, os.Error) {
+       return []byte(fmt.Sprintf("VALUE=%s", v)), nil
+}
+
+func (v *ValueGobber) GobDecode(data []byte) os.Error {
+       _, err := fmt.Sscanf(string(data), "VALUE=%s", (*string)(v))
+       return err
+}
+
+// Structs that include GobEncodable fields.
+
+type GobTest0 struct {
+       X int // guarantee we have  something in common with GobTest*
+       G *ByteStruct
+}
+
+type GobTest1 struct {
+       X int // guarantee we have  something in common with GobTest*
+       G *StringStruct
+}
+
+type GobTest2 struct {
+       X int    // guarantee we have  something in common with GobTest*
+       G string // not a GobEncoder - should give us errors
+}
+
+type GobTest3 struct {
+       X int // guarantee we have  something in common with GobTest*
+       G *Gobber
+}
+
+type GobTest4 struct {
+       X int // guarantee we have  something in common with GobTest*
+       V ValueGobber
+}
+
+type GobTest5 struct {
+       X int // guarantee we have  something in common with GobTest*
+       V *ValueGobber
+}
+
+type GobTestIgnoreEncoder struct {
+       X int // guarantee we have  something in common with GobTest*
+}
+
+type GobTestValueEncDec struct {
+       X int          // guarantee we have  something in common with GobTest*
+       G StringStruct // not a pointer.
+}
+
+type GobTestIndirectEncDec struct {
+       X int             // guarantee we have  something in common with GobTest*
+       G ***StringStruct // indirections to the receiver.
+}
+
+func TestGobEncoderField(t *testing.T) {
+       b := new(bytes.Buffer)
+       // First a field that's a structure.
+       enc := NewEncoder(b)
+       err := enc.Encode(GobTest0{17, &ByteStruct{'A'}})
+       if err != nil {
+               t.Fatal("encode error:", err)
+       }
+       dec := NewDecoder(b)
+       x := new(GobTest0)
+       err = dec.Decode(x)
+       if err != nil {
+               t.Fatal("decode error:", err)
+       }
+       if x.G.a != 'A' {
+               t.Errorf("expected 'A' got %c", x.G.a)
+       }
+       // Now a field that's not a structure.
+       b.Reset()
+       gobber := Gobber(23)
+       err = enc.Encode(GobTest3{17, &gobber})
+       if err != nil {
+               t.Fatal("encode error:", err)
+       }
+       y := new(GobTest3)
+       err = dec.Decode(y)
+       if err != nil {
+               t.Fatal("decode error:", err)
+       }
+       if *y.G != 23 {
+               t.Errorf("expected '23 got %d", *y.G)
+       }
+}
+
+// Even though the field is a value, we can still take its address
+// and should be able to call the methods.
+func TestGobEncoderValueField(t *testing.T) {
+       b := new(bytes.Buffer)
+       // First a field that's a structure.
+       enc := NewEncoder(b)
+       err := enc.Encode(GobTestValueEncDec{17, StringStruct{"HIJKL"}})
+       if err != nil {
+               t.Fatal("encode error:", err)
+       }
+       dec := NewDecoder(b)
+       x := new(GobTestValueEncDec)
+       err = dec.Decode(x)
+       if err != nil {
+               t.Fatal("decode error:", err)
+       }
+       if x.G.s != "HIJKL" {
+               t.Errorf("expected `HIJKL` got %s", x.G.s)
+       }
+}
+
+// GobEncode/Decode should work even if the value is
+// more indirect than the receiver.
+func TestGobEncoderIndirectField(t *testing.T) {
+       b := new(bytes.Buffer)
+       // First a field that's a structure.
+       enc := NewEncoder(b)
+       s := &StringStruct{"HIJKL"}
+       sp := &s
+       err := enc.Encode(GobTestIndirectEncDec{17, &sp})
+       if err != nil {
+               t.Fatal("encode error:", err)
+       }
+       dec := NewDecoder(b)
+       x := new(GobTestIndirectEncDec)
+       err = dec.Decode(x)
+       if err != nil {
+               t.Fatal("decode error:", err)
+       }
+       if (***x.G).s != "HIJKL" {
+               t.Errorf("expected `HIJKL` got %s", (***x.G).s)
+       }
+}
+
+// As long as the fields have the same name and implement the
+// interface, we can cross-connect them.  Not sure it's useful
+// and may even be bad but it works and it's hard to prevent
+// without exposing the contents of the object, which would
+// defeat the purpose.
+func TestGobEncoderFieldsOfDifferentType(t *testing.T) {
+       // first, string in field to byte in field
+       b := new(bytes.Buffer)
+       enc := NewEncoder(b)
+       err := enc.Encode(GobTest1{17, &StringStruct{"ABC"}})
+       if err != nil {
+               t.Fatal("encode error:", err)
+       }
+       dec := NewDecoder(b)
+       x := new(GobTest0)
+       err = dec.Decode(x)
+       if err != nil {
+               t.Fatal("decode error:", err)
+       }
+       if x.G.a != 'A' {
+               t.Errorf("expected 'A' got %c", x.G.a)
+       }
+       // now the other direction, byte in field to string in field
+       b.Reset()
+       err = enc.Encode(GobTest0{17, &ByteStruct{'X'}})
+       if err != nil {
+               t.Fatal("encode error:", err)
+       }
+       y := new(GobTest1)
+       err = dec.Decode(y)
+       if err != nil {
+               t.Fatal("decode error:", err)
+       }
+       if y.G.s != "XYZ" {
+               t.Fatalf("expected `XYZ` got %c", y.G.s)
+       }
+}
+
+// Test that we can encode a value and decode into a pointer.
+func TestGobEncoderValueEncoder(t *testing.T) {
+       // first, string in field to byte in field
+       b := new(bytes.Buffer)
+       enc := NewEncoder(b)
+       err := enc.Encode(GobTest4{17, ValueGobber("hello")})
+       if err != nil {
+               t.Fatal("encode error:", err)
+       }
+       dec := NewDecoder(b)
+       x := new(GobTest5)
+       err = dec.Decode(x)
+       if err != nil {
+               t.Fatal("decode error:", err)
+       }
+       if *x.V != "hello" {
+               t.Errorf("expected `hello` got %s", x.V)
+       }
+}
+
+func TestGobEncoderFieldTypeError(t *testing.T) {
+       // GobEncoder to non-decoder: error
+       b := new(bytes.Buffer)
+       enc := NewEncoder(b)
+       err := enc.Encode(GobTest1{17, &StringStruct{"ABC"}})
+       if err != nil {
+               t.Fatal("encode error:", err)
+       }
+       dec := NewDecoder(b)
+       x := &GobTest2{}
+       err = dec.Decode(x)
+       if err == nil {
+               t.Fatal("expected decode error for mismatched fields (encoder to non-decoder)")
+       }
+       if strings.Index(err.String(), "type") < 0 {
+               t.Fatal("expected type error; got", err)
+       }
+       // Non-encoder to GobDecoder: error
+       b.Reset()
+       err = enc.Encode(GobTest2{17, "ABC"})
+       if err != nil {
+               t.Fatal("encode error:", err)
+       }
+       y := &GobTest1{}
+       err = dec.Decode(y)
+       if err == nil {
+               t.Fatal("expected decode error for mistmatched fields (non-encoder to decoder)")
+       }
+       if strings.Index(err.String(), "type") < 0 {
+               t.Fatal("expected type error; got", err)
+       }
+}
+
+// Even though ByteStruct is a struct, it's treated as a singleton at the top level.
+func TestGobEncoderStructSingleton(t *testing.T) {
+       b := new(bytes.Buffer)
+       enc := NewEncoder(b)
+       err := enc.Encode(&ByteStruct{'A'})
+       if err != nil {
+               t.Fatal("encode error:", err)
+       }
+       dec := NewDecoder(b)
+       x := new(ByteStruct)
+       err = dec.Decode(x)
+       if err != nil {
+               t.Fatal("decode error:", err)
+       }
+       if x.a != 'A' {
+               t.Errorf("expected 'A' got %c", x.a)
+       }
+}
+
+func TestGobEncoderNonStructSingleton(t *testing.T) {
+       b := new(bytes.Buffer)
+       enc := NewEncoder(b)
+       err := enc.Encode(Gobber(1234))
+       if err != nil {
+               t.Fatal("encode error:", err)
+       }
+       dec := NewDecoder(b)
+       var x Gobber
+       err = dec.Decode(&x)
+       if err != nil {
+               t.Fatal("decode error:", err)
+       }
+       if x != 1234 {
+               t.Errorf("expected 1234 got %c", x)
+       }
+}
+
+func TestGobEncoderIgnoreStructField(t *testing.T) {
+       b := new(bytes.Buffer)
+       // First a field that's a structure.
+       enc := NewEncoder(b)
+       err := enc.Encode(GobTest0{17, &ByteStruct{'A'}})
+       if err != nil {
+               t.Fatal("encode error:", err)
+       }
+       dec := NewDecoder(b)
+       x := new(GobTestIgnoreEncoder)
+       err = dec.Decode(x)
+       if err != nil {
+               t.Fatal("decode error:", err)
+       }
+       if x.X != 17 {
+               t.Errorf("expected 17 got %c", x.X)
+       }
+}
+
+func TestGobEncoderIgnoreNonStructField(t *testing.T) {
+       b := new(bytes.Buffer)
+       // First a field that's a structure.
+       enc := NewEncoder(b)
+       gobber := Gobber(23)
+       err := enc.Encode(GobTest3{17, &gobber})
+       if err != nil {
+               t.Fatal("encode error:", err)
+       }
+       dec := NewDecoder(b)
+       x := new(GobTestIgnoreEncoder)
+       err = dec.Decode(x)
+       if err != nil {
+               t.Fatal("decode error:", err)
+       }
+       if x.X != 17 {
+               t.Errorf("expected 17 got %c", x.X)
+       }
+}
diff --git a/libgo/go/gob/timing_test.go b/libgo/go/gob/timing_test.go
new file mode 100644 (file)
index 0000000..645f4fe
--- /dev/null
@@ -0,0 +1,90 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+       "bytes"
+       "fmt"
+       "io"
+       "os"
+       "runtime"
+       "testing"
+)
+
+type Bench struct {
+       A int
+       B float64
+       C string
+       D []byte
+}
+
+func benchmarkEndToEnd(r io.Reader, w io.Writer, b *testing.B) {
+       b.StopTimer()
+       enc := NewEncoder(w)
+       dec := NewDecoder(r)
+       bench := &Bench{7, 3.2, "now is the time", []byte("for all good men")}
+       b.StartTimer()
+       for i := 0; i < b.N; i++ {
+               if enc.Encode(bench) != nil {
+                       panic("encode error")
+               }
+               if dec.Decode(bench) != nil {
+                       panic("decode error")
+               }
+       }
+}
+
+func BenchmarkEndToEndPipe(b *testing.B) {
+       r, w, err := os.Pipe()
+       if err != nil {
+               panic("can't get pipe:" + err.String())
+       }
+       benchmarkEndToEnd(r, w, b)
+}
+
+func BenchmarkEndToEndByteBuffer(b *testing.B) {
+       var buf bytes.Buffer
+       benchmarkEndToEnd(&buf, &buf, b)
+}
+
+func TestCountEncodeMallocs(t *testing.T) {
+       var buf bytes.Buffer
+       enc := NewEncoder(&buf)
+       bench := &Bench{7, 3.2, "now is the time", []byte("for all good men")}
+       mallocs := 0 - runtime.MemStats.Mallocs
+       const count = 1000
+       for i := 0; i < count; i++ {
+               err := enc.Encode(bench)
+               if err != nil {
+                       t.Fatal("encode:", err)
+               }
+       }
+       mallocs += runtime.MemStats.Mallocs
+       fmt.Printf("mallocs per encode of type Bench: %d\n", mallocs/count)
+}
+
+func TestCountDecodeMallocs(t *testing.T) {
+       var buf bytes.Buffer
+       enc := NewEncoder(&buf)
+       bench := &Bench{7, 3.2, "now is the time", []byte("for all good men")}
+       const count = 1000
+       for i := 0; i < count; i++ {
+               err := enc.Encode(bench)
+               if err != nil {
+                       t.Fatal("encode:", err)
+               }
+       }
+       dec := NewDecoder(&buf)
+       mallocs := 0 - runtime.MemStats.Mallocs
+       for i := 0; i < count; i++ {
+               *bench = Bench{}
+               err := dec.Decode(&bench)
+               if err != nil {
+                       t.Fatal("decode:", err)
+               }
+       }
+       mallocs += runtime.MemStats.Mallocs
+       fmt.Printf("mallocs per decode of type Bench: %d\n", mallocs/count)
+}
index 6e3f148b4e7d3e95e3a773b064e0533315ebb9d2..fc620f5c7c1011967e997576a97e73cf36517587 100644 (file)
@@ -9,15 +9,21 @@ import (
        "os"
        "reflect"
        "sync"
+       "unicode"
+       "utf8"
 )
 
 // userTypeInfo stores the information associated with a type the user has handed
 // to the package.  It's computed once and stored in a map keyed by reflection
 // type.
 type userTypeInfo struct {
-       user  reflect.Type // the type the user handed us
-       base  reflect.Type // the base type after all indirections
-       indir int          // number of indirections to reach the base type
+       user         reflect.Type // the type the user handed us
+       base         reflect.Type // the base type after all indirections
+       indir        int          // number of indirections to reach the base type
+       isGobEncoder bool         // does the type implement GobEncoder?
+       isGobDecoder bool         // does the type implement GobDecoder?
+       encIndir     int8         // number of indirections to reach the receiver type; may be negative
+       decIndir     int8         // number of indirections to reach the receiver type; may be negative
 }
 
 var (
@@ -68,10 +74,73 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) {
                }
                ut.indir++
        }
+       ut.isGobEncoder, ut.encIndir = implementsInterface(ut.user, gobEncoderCheck)
+       ut.isGobDecoder, ut.decIndir = implementsInterface(ut.user, gobDecoderCheck)
        userTypeCache[rt] = ut
        return
 }
 
+const (
+       gobEncodeMethodName = "GobEncode"
+       gobDecodeMethodName = "GobDecode"
+)
+
+// implements returns whether the type implements the interface, as encoded
+// in the check function.
+func implements(typ reflect.Type, check func(typ reflect.Type) bool) bool {
+       if typ.NumMethod() == 0 { // avoid allocations etc. unless there's some chance
+               return false
+       }
+       return check(typ)
+}
+
+// gobEncoderCheck makes the type assertion a boolean function.
+func gobEncoderCheck(typ reflect.Type) bool {
+       _, ok := reflect.MakeZero(typ).Interface().(GobEncoder)
+       return ok
+}
+
+// gobDecoderCheck makes the type assertion a boolean function.
+func gobDecoderCheck(typ reflect.Type) bool {
+       _, ok := reflect.MakeZero(typ).Interface().(GobDecoder)
+       return ok
+}
+
+// implementsInterface reports whether the type implements the
+// interface. (The actual check is done through the provided function.)
+// It also returns the number of indirections required to get to the
+// implementation.
+func implementsInterface(typ reflect.Type, check func(typ reflect.Type) bool) (success bool, indir int8) {
+       if typ == nil {
+               return
+       }
+       rt := typ
+       // The type might be a pointer and we need to keep
+       // dereferencing to the base type until we find an implementation.
+       for {
+               if implements(rt, check) {
+                       return true, indir
+               }
+               if p, ok := rt.(*reflect.PtrType); ok {
+                       indir++
+                       if indir > 100 { // insane number of indirections
+                               return false, 0
+                       }
+                       rt = p.Elem()
+                       continue
+               }
+               break
+       }
+       // No luck yet, but if this is a base type (non-pointer), the pointer might satisfy.
+       if _, ok := typ.(*reflect.PtrType); !ok {
+               // Not a pointer, but does the pointer work?
+               if implements(reflect.PtrTo(typ), check) {
+                       return true, -1
+               }
+       }
+       return false, 0
+}
+
 // userType returns, and saves, the information associated with user-provided type rt.
 // If the user type is not valid, it calls error.
 func userType(rt reflect.Type) *userTypeInfo {
@@ -153,22 +222,24 @@ func (t *CommonType) name() string { return t.Name }
 
 var (
        // Primordial types, needed during initialization.
-       tBool      = bootstrapType("bool", false, 1)
-       tInt       = bootstrapType("int", int(0), 2)
-       tUint      = bootstrapType("uint", uint(0), 3)
-       tFloat     = bootstrapType("float", float64(0), 4)
-       tBytes     = bootstrapType("bytes", make([]byte, 0), 5)
-       tString    = bootstrapType("string", "", 6)
-       tComplex   = bootstrapType("complex", 0+0i, 7)
-       tInterface = bootstrapType("interface", interface{}(nil), 8)
+       // Always passed as pointers so the interface{} type
+       // goes through without losing its interfaceness.
+       tBool      = bootstrapType("bool", (*bool)(nil), 1)
+       tInt       = bootstrapType("int", (*int)(nil), 2)
+       tUint      = bootstrapType("uint", (*uint)(nil), 3)
+       tFloat     = bootstrapType("float", (*float64)(nil), 4)
+       tBytes     = bootstrapType("bytes", (*[]byte)(nil), 5)
+       tString    = bootstrapType("string", (*string)(nil), 6)
+       tComplex   = bootstrapType("complex", (*complex128)(nil), 7)
+       tInterface = bootstrapType("interface", (*interface{})(nil), 8)
        // Reserve some Ids for compatible expansion
-       tReserved7 = bootstrapType("_reserved1", struct{ r7 int }{}, 9)
-       tReserved6 = bootstrapType("_reserved1", struct{ r6 int }{}, 10)
-       tReserved5 = bootstrapType("_reserved1", struct{ r5 int }{}, 11)
-       tReserved4 = bootstrapType("_reserved1", struct{ r4 int }{}, 12)
-       tReserved3 = bootstrapType("_reserved1", struct{ r3 int }{}, 13)
-       tReserved2 = bootstrapType("_reserved1", struct{ r2 int }{}, 14)
-       tReserved1 = bootstrapType("_reserved1", struct{ r1 int }{}, 15)
+       tReserved7 = bootstrapType("_reserved1", (*struct{ r7 int })(nil), 9)
+       tReserved6 = bootstrapType("_reserved1", (*struct{ r6 int })(nil), 10)
+       tReserved5 = bootstrapType("_reserved1", (*struct{ r5 int })(nil), 11)
+       tReserved4 = bootstrapType("_reserved1", (*struct{ r4 int })(nil), 12)
+       tReserved3 = bootstrapType("_reserved1", (*struct{ r3 int })(nil), 13)
+       tReserved2 = bootstrapType("_reserved1", (*struct{ r2 int })(nil), 14)
+       tReserved1 = bootstrapType("_reserved1", (*struct{ r1 int })(nil), 15)
 )
 
 // Predefined because it's needed by the Decoder
@@ -229,6 +300,23 @@ func (a *arrayType) safeString(seen map[typeId]bool) string {
 
 func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) }
 
+// GobEncoder type (something that implements the GobEncoder interface)
+type gobEncoderType struct {
+       CommonType
+}
+
+func newGobEncoderType(name string) *gobEncoderType {
+       g := &gobEncoderType{CommonType{Name: name}}
+       setTypeId(g)
+       return g
+}
+
+func (g *gobEncoderType) safeString(seen map[typeId]bool) string {
+       return g.Name
+}
+
+func (g *gobEncoderType) string() string { return g.Name }
+
 // Map type
 type mapType struct {
        CommonType
@@ -324,11 +412,16 @@ func newStructType(name string) *structType {
        return s
 }
 
-func (s *structType) init(field []*fieldType) {
-       s.Field = field
-}
-
-func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
+// newTypeObject allocates a gobType for the reflection type rt.
+// Unless ut represents a GobEncoder, rt should be the base type
+// of ut.
+// This is only called from the encoding side. The decoding side
+// works through typeIds and userTypeInfos alone.
+func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) {
+       // Does this type implement GobEncoder?
+       if ut.isGobEncoder {
+               return newGobEncoderType(name), nil
+       }
        var err os.Error
        var type0, type1 gobType
        defer func() {
@@ -364,7 +457,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
        case *reflect.ArrayType:
                at := newArrayType(name)
                types[rt] = at
-               type0, err = getType("", t.Elem())
+               type0, err = getBaseType("", t.Elem())
                if err != nil {
                        return nil, err
                }
@@ -382,11 +475,11 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
        case *reflect.MapType:
                mt := newMapType(name)
                types[rt] = mt
-               type0, err = getType("", t.Key())
+               type0, err = getBaseType("", t.Key())
                if err != nil {
                        return nil, err
                }
-               type1, err = getType("", t.Elem())
+               type1, err = getBaseType("", t.Elem())
                if err != nil {
                        return nil, err
                }
@@ -400,7 +493,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
                }
                st := newSliceType(name)
                types[rt] = st
-               type0, err = getType(t.Elem().Name(), t.Elem())
+               type0, err = getBaseType(t.Elem().Name(), t.Elem())
                if err != nil {
                        return nil, err
                }
@@ -411,22 +504,23 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
                st := newStructType(name)
                types[rt] = st
                idToType[st.id()] = st
-               field := make([]*fieldType, t.NumField())
                for i := 0; i < t.NumField(); i++ {
                        f := t.Field(i)
+                       if !isExported(f.Name) {
+                               continue
+                       }
                        typ := userType(f.Type).base
                        tname := typ.Name()
                        if tname == "" {
                                t := userType(f.Type).base
                                tname = t.String()
                        }
-                       gt, err := getType(tname, f.Type)
+                       gt, err := getBaseType(tname, f.Type)
                        if err != nil {
                                return nil, err
                        }
-                       field[i] = &fieldType{f.Name, gt.id()}
+                       st.Field = append(st.Field, &fieldType{f.Name, gt.id()})
                }
-               st.init(field)
                return st, nil
 
        default:
@@ -435,15 +529,30 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
        return nil, nil
 }
 
+// isExported reports whether this is an exported - upper case - name.
+func isExported(name string) bool {
+       rune, _ := utf8.DecodeRuneInString(name)
+       return unicode.IsUpper(rune)
+}
+
+// getBaseType returns the Gob type describing the given reflect.Type's base type.
+// typeLock must be held.
+func getBaseType(name string, rt reflect.Type) (gobType, os.Error) {
+       ut := userType(rt)
+       return getType(name, ut, ut.base)
+}
+
 // getType returns the Gob type describing the given reflect.Type.
+// Should be called only when handling GobEncoders/Decoders,
+// which may be pointers.  All other types are handled through the
+//  base type, never a pointer.
 // typeLock must be held.
-func getType(name string, rt reflect.Type) (gobType, os.Error) {
-       rt = userType(rt).base
+func getType(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) {
        typ, present := types[rt]
        if present {
                return typ, nil
        }
-       typ, err := newTypeObject(name, rt)
+       typ, err := newTypeObject(name, ut, rt)
        if err == nil {
                types[rt] = typ
        }
@@ -457,9 +566,10 @@ func checkId(want, got typeId) {
        }
 }
 
-// used for building the basic types; called only from init()
+// used for building the basic types; called only from init().  the incoming
+// interface always refers to a pointer.
 func bootstrapType(name string, e interface{}, expect typeId) typeId {
-       rt := reflect.Typeof(e)
+       rt := reflect.Typeof(e).(*reflect.PtrType).Elem()
        _, present := types[rt]
        if present {
                panic("bootstrap type already present: " + name + ", " + rt.String())
@@ -484,10 +594,11 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId {
 // To maintain binary compatibility, if you extend this type, always put
 // the new fields last.
 type wireType struct {
-       ArrayT  *arrayType
-       SliceT  *sliceType
-       StructT *structType
-       MapT    *mapType
+       ArrayT      *arrayType
+       SliceT      *sliceType
+       StructT     *structType
+       MapT        *mapType
+       GobEncoderT *gobEncoderType
 }
 
 func (w *wireType) string() string {
@@ -504,6 +615,8 @@ func (w *wireType) string() string {
                return w.StructT.Name
        case w.MapT != nil:
                return w.MapT.Name
+       case w.GobEncoderT != nil:
+               return w.GobEncoderT.Name
        }
        return unknown
 }
@@ -516,49 +629,88 @@ type typeInfo struct {
 
 var typeInfoMap = make(map[reflect.Type]*typeInfo) // protected by typeLock
 
-// The reflection type must have all its indirections processed out.
 // typeLock must be held.
-func getTypeInfo(rt reflect.Type) (*typeInfo, os.Error) {
-       if rt.Kind() == reflect.Ptr {
-               panic("pointer type in getTypeInfo: " + rt.String())
+func getTypeInfo(ut *userTypeInfo) (*typeInfo, os.Error) {
+       rt := ut.base
+       if ut.isGobEncoder {
+               // We want the user type, not the base type.
+               rt = ut.user
        }
        info, ok := typeInfoMap[rt]
-       if !ok {
-               info = new(typeInfo)
-               name := rt.Name()
-               gt, err := getType(name, rt)
+       if ok {
+               return info, nil
+       }
+       info = new(typeInfo)
+       gt, err := getBaseType(rt.Name(), rt)
+       if err != nil {
+               return nil, err
+       }
+       info.id = gt.id()
+
+       if ut.isGobEncoder {
+               userType, err := getType(rt.Name(), ut, rt)
                if err != nil {
                        return nil, err
                }
-               info.id = gt.id()
-               t := info.id.gobType()
-               switch typ := rt.(type) {
-               case *reflect.ArrayType:
-                       info.wire = &wireType{ArrayT: t.(*arrayType)}
-               case *reflect.MapType:
-                       info.wire = &wireType{MapT: t.(*mapType)}
-               case *reflect.SliceType:
-                       // []byte == []uint8 is a special case handled separately
-                       if typ.Elem().Kind() != reflect.Uint8 {
-                               info.wire = &wireType{SliceT: t.(*sliceType)}
-                       }
-               case *reflect.StructType:
-                       info.wire = &wireType{StructT: t.(*structType)}
+               info.wire = &wireType{GobEncoderT: userType.id().gobType().(*gobEncoderType)}
+               typeInfoMap[ut.user] = info
+               return info, nil
+       }
+
+       t := info.id.gobType()
+       switch typ := rt.(type) {
+       case *reflect.ArrayType:
+               info.wire = &wireType{ArrayT: t.(*arrayType)}
+       case *reflect.MapType:
+               info.wire = &wireType{MapT: t.(*mapType)}
+       case *reflect.SliceType:
+               // []byte == []uint8 is a special case handled separately
+               if typ.Elem().Kind() != reflect.Uint8 {
+                       info.wire = &wireType{SliceT: t.(*sliceType)}
                }
-               typeInfoMap[rt] = info
+       case *reflect.StructType:
+               info.wire = &wireType{StructT: t.(*structType)}
        }
+       typeInfoMap[rt] = info
        return info, nil
 }
 
 // Called only when a panic is acceptable and unexpected.
 func mustGetTypeInfo(rt reflect.Type) *typeInfo {
-       t, err := getTypeInfo(rt)
+       t, err := getTypeInfo(userType(rt))
        if err != nil {
                panic("getTypeInfo: " + err.String())
        }
        return t
 }
 
+// GobEncoder is the interface describing data that provides its own
+// representation for encoding values for transmission to a GobDecoder.
+// A type that implements GobEncoder and GobDecoder has complete
+// control over the representation of its data and may therefore
+// contain things such as private fields, channels, and functions,
+// which are not usually transmissable in gob streams.
+//
+// Note: Since gobs can be stored permanently, It is good design
+// to guarantee the encoding used by a GobEncoder is stable as the
+// software evolves.  For instance, it might make sense for GobEncode
+// to include a version number in the encoding.
+type GobEncoder interface {
+       // GobEncode returns a byte slice representing the encoding of the
+       // receiver for transmission to a GobDecoder, usually of the same
+       // concrete type.
+       GobEncode() ([]byte, os.Error)
+}
+
+// GobDecoder is the interface describing data that provides its own
+// routine for decoding transmitted values sent by a GobEncoder.
+type GobDecoder interface {
+       // GobDecode overwrites the receiver, which must be a pointer,
+       // with the value represented by the byte slice, which was written
+       // by GobEncode, usually for the same concrete type.
+       GobDecode([]byte) os.Error
+}
+
 var (
        nameToConcreteType = make(map[string]reflect.Type)
        concreteTypeToName = make(map[reflect.Type]string)
index 5aecde103a5c23c963530470a37f7a739a184987..ffd1345e5c0c135a911ef0298b7f261e323051be 100644 (file)
@@ -26,7 +26,7 @@ var basicTypes = []typeT{
 func getTypeUnlocked(name string, rt reflect.Type) gobType {
        typeLock.Lock()
        defer typeLock.Unlock()
-       t, err := getType(name, rt)
+       t, err := getBaseType(name, rt)
        if err != nil {
                panic("getTypeUnlocked: " + err.String())
        }
@@ -126,27 +126,27 @@ func TestMapType(t *testing.T) {
 }
 
 type Bar struct {
-       x string
+       X string
 }
 
 // This structure has pointers and refers to itself, making it a good test case.
 type Foo struct {
-       a int
-       b int32 // will become int
-       c string
-       d []byte
-       e *float64    // will become float64
-       f ****float64 // will become float64
-       g *Bar
-       h *Bar // should not interpolate the definition of Bar again
-       i *Foo // will not explode
+       A int
+       B int32 // will become int
+       C string
+       D []byte
+       E *float64    // will become float64
+       F ****float64 // will become float64
+       G *Bar
+       H *Bar // should not interpolate the definition of Bar again
+       I *Foo // will not explode
 }
 
 func TestStructType(t *testing.T) {
        sstruct := getTypeUnlocked("Foo", reflect.Typeof(Foo{}))
        str := sstruct.string()
        // If we can print it correctly, we built it correctly.
-       expected := "Foo = struct { a int; b int; c string; d bytes; e float; f float; g Bar = struct { x string; }; h Bar; i Foo; }"
+       expected := "Foo = struct { A int; B int; C string; D bytes; E float; F float; G Bar = struct { X string; }; H Bar; I Foo; }"
        if str != expected {
                t.Errorf("struct printed as %q; expected %q", str, expected)
        }
diff --git a/libgo/go/hash/fnv/fnv.go b/libgo/go/hash/fnv/fnv.go
new file mode 100644 (file)
index 0000000..66ab5a6
--- /dev/null
@@ -0,0 +1,133 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The fnv package implements FNV-1 and FNV-1a,
+// non-cryptographic hash functions created by
+// Glenn Fowler, Landon Curt Noll, and Phong Vo.
+// See http://isthe.com/chongo/tech/comp/fnv/.
+package fnv
+
+import (
+       "encoding/binary"
+       "hash"
+       "os"
+       "unsafe"
+)
+
+type (
+       sum32  uint32
+       sum32a uint32
+       sum64  uint64
+       sum64a uint64
+)
+
+const (
+       offset32 = 2166136261
+       offset64 = 14695981039346656037
+       prime32  = 16777619
+       prime64  = 1099511628211
+)
+
+// New32 returns a new 32-bit FNV-1 hash.Hash.
+func New32() hash.Hash32 {
+       var s sum32 = offset32
+       return &s
+}
+
+// New32a returns a new 32-bit FNV-1a hash.Hash.
+func New32a() hash.Hash32 {
+       var s sum32a = offset32
+       return &s
+}
+
+// New64 returns a new 64-bit FNV-1 hash.Hash.
+func New64() hash.Hash64 {
+       var s sum64 = offset64
+       return &s
+}
+
+// New64a returns a new 64-bit FNV-1a hash.Hash.
+func New64a() hash.Hash64 {
+       var s sum64a = offset64
+       return &s
+}
+
+func (s *sum32) Reset()  { *s = offset32 }
+func (s *sum32a) Reset() { *s = offset32 }
+func (s *sum64) Reset()  { *s = offset64 }
+func (s *sum64a) Reset() { *s = offset64 }
+
+func (s *sum32) Sum32() uint32  { return uint32(*s) }
+func (s *sum32a) Sum32() uint32 { return uint32(*s) }
+func (s *sum64) Sum64() uint64  { return uint64(*s) }
+func (s *sum64a) Sum64() uint64 { return uint64(*s) }
+
+func (s *sum32) Write(data []byte) (int, os.Error) {
+       hash := *s
+       for _, c := range data {
+               hash *= prime32
+               hash ^= sum32(c)
+       }
+       *s = hash
+       return len(data), nil
+}
+
+func (s *sum32a) Write(data []byte) (int, os.Error) {
+       hash := *s
+       for _, c := range data {
+               hash ^= sum32a(c)
+               hash *= prime32
+       }
+       *s = hash
+       return len(data), nil
+}
+
+func (s *sum64) Write(data []byte) (int, os.Error) {
+       hash := *s
+       for _, c := range data {
+               hash *= prime64
+               hash ^= sum64(c)
+       }
+       *s = hash
+       return len(data), nil
+}
+
+func (s *sum64a) Write(data []byte) (int, os.Error) {
+       hash := *s
+       for _, c := range data {
+               hash ^= sum64a(c)
+               hash *= prime64
+       }
+       *s = hash
+       return len(data), nil
+}
+
+func (s *sum32) Size() int  { return unsafe.Sizeof(*s) }
+func (s *sum32a) Size() int { return unsafe.Sizeof(*s) }
+func (s *sum64) Size() int  { return unsafe.Sizeof(*s) }
+func (s *sum64a) Size() int { return unsafe.Sizeof(*s) }
+
+func (s *sum32) Sum() []byte {
+       a := make([]byte, unsafe.Sizeof(*s))
+       binary.BigEndian.PutUint32(a, uint32(*s))
+       return a
+}
+
+func (s *sum32a) Sum() []byte {
+       a := make([]byte, unsafe.Sizeof(*s))
+       binary.BigEndian.PutUint32(a, uint32(*s))
+       return a
+}
+
+func (s *sum64) Sum() []byte {
+       a := make([]byte, unsafe.Sizeof(*s))
+       binary.BigEndian.PutUint64(a, uint64(*s))
+       return a
+}
+
+func (s *sum64a) Sum() []byte {
+       a := make([]byte, unsafe.Sizeof(*s))
+       binary.BigEndian.PutUint64(a, uint64(*s))
+       return a
+}
diff --git a/libgo/go/hash/fnv/fnv_test.go b/libgo/go/hash/fnv/fnv_test.go
new file mode 100644 (file)
index 0000000..3ea3fe6
--- /dev/null
@@ -0,0 +1,167 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fnv
+
+import (
+       "bytes"
+       "encoding/binary"
+       "hash"
+       "testing"
+)
+
+const testDataSize = 40
+
+type golden struct {
+       sum  []byte
+       text string
+}
+
+var golden32 = []golden{
+       {[]byte{0x81, 0x1c, 0x9d, 0xc5}, ""},
+       {[]byte{0x05, 0x0c, 0x5d, 0x7e}, "a"},
+       {[]byte{0x70, 0x77, 0x2d, 0x38}, "ab"},
+       {[]byte{0x43, 0x9c, 0x2f, 0x4b}, "abc"},
+}
+
+var golden32a = []golden{
+       {[]byte{0x81, 0x1c, 0x9d, 0xc5}, ""},
+       {[]byte{0xe4, 0x0c, 0x29, 0x2c}, "a"},
+       {[]byte{0x4d, 0x25, 0x05, 0xca}, "ab"},
+       {[]byte{0x1a, 0x47, 0xe9, 0x0b}, "abc"},
+}
+
+var golden64 = []golden{
+       {[]byte{0xcb, 0xf2, 0x9c, 0xe4, 0x84, 0x22, 0x23, 0x25}, ""},
+       {[]byte{0xaf, 0x63, 0xbd, 0x4c, 0x86, 0x01, 0xb7, 0xbe}, "a"},
+       {[]byte{0x08, 0x32, 0x67, 0x07, 0xb4, 0xeb, 0x37, 0xb8}, "ab"},
+       {[]byte{0xd8, 0xdc, 0xca, 0x18, 0x6b, 0xaf, 0xad, 0xcb}, "abc"},
+}
+
+var golden64a = []golden{
+       {[]byte{0xcb, 0xf2, 0x9c, 0xe4, 0x84, 0x22, 0x23, 0x25}, ""},
+       {[]byte{0xaf, 0x63, 0xdc, 0x4c, 0x86, 0x01, 0xec, 0x8c}, "a"},
+       {[]byte{0x08, 0x9c, 0x44, 0x07, 0xb5, 0x45, 0x98, 0x6a}, "ab"},
+       {[]byte{0xe7, 0x1f, 0xa2, 0x19, 0x05, 0x41, 0x57, 0x4b}, "abc"},
+}
+
+func TestGolden32(t *testing.T) {
+       testGolden(t, New32(), golden32)
+}
+
+func TestGolden32a(t *testing.T) {
+       testGolden(t, New32a(), golden32a)
+}
+
+func TestGolden64(t *testing.T) {
+       testGolden(t, New64(), golden64)
+}
+
+func TestGolden64a(t *testing.T) {
+       testGolden(t, New64a(), golden64a)
+}
+
+func testGolden(t *testing.T, hash hash.Hash, gold []golden) {
+       for _, g := range gold {
+               hash.Reset()
+               done, error := hash.Write([]byte(g.text))
+               if error != nil {
+                       t.Fatalf("write error: %s", error)
+               }
+               if done != len(g.text) {
+                       t.Fatalf("wrote only %d out of %d bytes", done, len(g.text))
+               }
+               if actual := hash.Sum(); !bytes.Equal(g.sum, actual) {
+                       t.Errorf("hash(%q) = 0x%x want 0x%x", g.text, actual, g.sum)
+               }
+       }
+}
+
+func TestIntegrity32(t *testing.T) {
+       testIntegrity(t, New32())
+}
+
+func TestIntegrity32a(t *testing.T) {
+       testIntegrity(t, New32a())
+}
+
+func TestIntegrity64(t *testing.T) {
+       testIntegrity(t, New64())
+}
+
+func TestIntegrity64a(t *testing.T) {
+       testIntegrity(t, New64a())
+}
+
+func testIntegrity(t *testing.T, h hash.Hash) {
+       data := []byte{'1', '2', 3, 4, 5}
+       h.Write(data)
+       sum := h.Sum()
+
+       if size := h.Size(); size != len(sum) {
+               t.Fatalf("Size()=%d but len(Sum())=%d", size, len(sum))
+       }
+
+       if a := h.Sum(); !bytes.Equal(sum, a) {
+               t.Fatalf("first Sum()=0x%x, second Sum()=0x%x", sum, a)
+       }
+
+       h.Reset()
+       h.Write(data)
+       if a := h.Sum(); !bytes.Equal(sum, a) {
+               t.Fatalf("Sum()=0x%x, but after Reset() Sum()=0x%x", sum, a)
+       }
+
+       h.Reset()
+       h.Write(data[:2])
+       h.Write(data[2:])
+       if a := h.Sum(); !bytes.Equal(sum, a) {
+               t.Fatalf("Sum()=0x%x, but with partial writes, Sum()=0x%x", sum, a)
+       }
+
+       switch h.Size() {
+       case 4:
+               sum32 := h.(hash.Hash32).Sum32()
+               if sum32 != binary.BigEndian.Uint32(sum) {
+                       t.Fatalf("Sum()=0x%x, but Sum32()=0x%x", sum, sum32)
+               }
+       case 8:
+               sum64 := h.(hash.Hash64).Sum64()
+               if sum64 != binary.BigEndian.Uint64(sum) {
+                       t.Fatalf("Sum()=0x%x, but Sum64()=0x%x", sum, sum64)
+               }
+       }
+}
+
+func Benchmark32(b *testing.B) {
+       benchmark(b, New32())
+}
+
+func Benchmark32a(b *testing.B) {
+       benchmark(b, New32a())
+}
+
+func Benchmark64(b *testing.B) {
+       benchmark(b, New64())
+}
+
+func Benchmark64a(b *testing.B) {
+       benchmark(b, New64a())
+}
+
+func benchmark(b *testing.B, h hash.Hash) {
+       b.ResetTimer()
+       b.SetBytes(testDataSize)
+       data := make([]byte, testDataSize)
+       for i, _ := range data {
+               data[i] = byte(i + 'a')
+       }
+
+       b.StartTimer()
+       for todo := b.N; todo != 0; todo-- {
+               h.Reset()
+               h.Write(data)
+               h.Sum()
+       }
+}
diff --git a/libgo/go/http/cgi/child.go b/libgo/go/http/cgi/child.go
new file mode 100644 (file)
index 0000000..c7d48b9
--- /dev/null
@@ -0,0 +1,192 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements CGI from the perspective of a child
+// process.
+
+package cgi
+
+import (
+       "bufio"
+       "fmt"
+       "http"
+       "io"
+       "io/ioutil"
+       "os"
+       "strconv"
+       "strings"
+)
+
+// Request returns the HTTP request as represented in the current
+// environment. This assumes the current program is being run
+// by a web server in a CGI environment.
+func Request() (*http.Request, os.Error) {
+       return requestFromEnvironment(envMap(os.Environ()))
+}
+
+func envMap(env []string) map[string]string {
+       m := make(map[string]string)
+       for _, kv := range env {
+               if idx := strings.Index(kv, "="); idx != -1 {
+                       m[kv[:idx]] = kv[idx+1:]
+               }
+       }
+       return m
+}
+
+// These environment variables are manually copied into Request
+var skipHeader = map[string]bool{
+       "HTTP_HOST":       true,
+       "HTTP_REFERER":    true,
+       "HTTP_USER_AGENT": true,
+}
+
+func requestFromEnvironment(env map[string]string) (*http.Request, os.Error) {
+       r := new(http.Request)
+       r.Method = env["REQUEST_METHOD"]
+       if r.Method == "" {
+               return nil, os.NewError("cgi: no REQUEST_METHOD in environment")
+       }
+       r.Close = true
+       r.Trailer = http.Header{}
+       r.Header = http.Header{}
+
+       r.Host = env["HTTP_HOST"]
+       r.Referer = env["HTTP_REFERER"]
+       r.UserAgent = env["HTTP_USER_AGENT"]
+
+       // CGI doesn't allow chunked requests, so these should all be accurate:
+       r.Proto = "HTTP/1.0"
+       r.ProtoMajor = 1
+       r.ProtoMinor = 0
+       r.TransferEncoding = nil
+
+       if lenstr := env["CONTENT_LENGTH"]; lenstr != "" {
+               clen, err := strconv.Atoi64(lenstr)
+               if err != nil {
+                       return nil, os.NewError("cgi: bad CONTENT_LENGTH in environment: " + lenstr)
+               }
+               r.ContentLength = clen
+               r.Body = ioutil.NopCloser(io.LimitReader(os.Stdin, clen))
+       }
+
+       // Copy "HTTP_FOO_BAR" variables to "Foo-Bar" Headers
+       for k, v := range env {
+               if !strings.HasPrefix(k, "HTTP_") || skipHeader[k] {
+                       continue
+               }
+               r.Header.Add(strings.Replace(k[5:], "_", "-", -1), v)
+       }
+
+       // TODO: cookies.  parsing them isn't exported, though.
+
+       if r.Host != "" {
+               // Hostname is provided, so we can reasonably construct a URL,
+               // even if we have to assume 'http' for the scheme.
+               r.RawURL = "http://" + r.Host + env["REQUEST_URI"]
+               url, err := http.ParseURL(r.RawURL)
+               if err != nil {
+                       return nil, os.NewError("cgi: failed to parse host and REQUEST_URI into a URL: " + r.RawURL)
+               }
+               r.URL = url
+       }
+       // Fallback logic if we don't have a Host header or the URL
+       // failed to parse
+       if r.URL == nil {
+               r.RawURL = env["REQUEST_URI"]
+               url, err := http.ParseURL(r.RawURL)
+               if err != nil {
+                       return nil, os.NewError("cgi: failed to parse REQUEST_URI into a URL: " + r.RawURL)
+               }
+               r.URL = url
+       }
+       return r, nil
+}
+
+// Serve executes the provided Handler on the currently active CGI
+// request, if any. If there's no current CGI environment
+// an error is returned. The provided handler may be nil to use
+// http.DefaultServeMux.
+func Serve(handler http.Handler) os.Error {
+       req, err := Request()
+       if err != nil {
+               return err
+       }
+       if handler == nil {
+               handler = http.DefaultServeMux
+       }
+       rw := &response{
+               req:    req,
+               header: make(http.Header),
+               bufw:   bufio.NewWriter(os.Stdout),
+       }
+       handler.ServeHTTP(rw, req)
+       if err = rw.bufw.Flush(); err != nil {
+               return err
+       }
+       return nil
+}
+
+type response struct {
+       req        *http.Request
+       header     http.Header
+       bufw       *bufio.Writer
+       headerSent bool
+}
+
+func (r *response) Flush() {
+       r.bufw.Flush()
+}
+
+func (r *response) RemoteAddr() string {
+       return os.Getenv("REMOTE_ADDR")
+}
+
+func (r *response) Header() http.Header {
+       return r.header
+}
+
+func (r *response) Write(p []byte) (n int, err os.Error) {
+       if !r.headerSent {
+               r.WriteHeader(http.StatusOK)
+       }
+       return r.bufw.Write(p)
+}
+
+func (r *response) WriteHeader(code int) {
+       if r.headerSent {
+               // Note: explicitly using Stderr, as Stdout is our HTTP output.
+               fmt.Fprintf(os.Stderr, "CGI attempted to write header twice on request for %s", r.req.URL)
+               return
+       }
+       r.headerSent = true
+       fmt.Fprintf(r.bufw, "Status: %d %s\r\n", code, http.StatusText(code))
+
+       // Set a default Content-Type
+       if _, hasType := r.header["Content-Type"]; !hasType {
+               r.header.Add("Content-Type", "text/html; charset=utf-8")
+       }
+
+       // TODO: add a method on http.Header to write itself to an io.Writer?
+       // This is duplicated code.
+       for k, vv := range r.header {
+               for _, v := range vv {
+                       v = strings.Replace(v, "\n", "", -1)
+                       v = strings.Replace(v, "\r", "", -1)
+                       v = strings.TrimSpace(v)
+                       fmt.Fprintf(r.bufw, "%s: %s\r\n", k, v)
+               }
+       }
+       r.bufw.Write([]byte("\r\n"))
+       r.bufw.Flush()
+}
+
+func (r *response) UsingTLS() bool {
+       // There's apparently a de-facto standard for this.
+       // http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636
+       if s := os.Getenv("HTTPS"); s == "on" || s == "ON" || s == "1" {
+               return true
+       }
+       return false
+}
diff --git a/libgo/go/http/cgi/child_test.go b/libgo/go/http/cgi/child_test.go
new file mode 100644 (file)
index 0000000..db0e09c
--- /dev/null
@@ -0,0 +1,83 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests for CGI (the child process perspective)
+
+package cgi
+
+import (
+       "testing"
+)
+
+func TestRequest(t *testing.T) {
+       env := map[string]string{
+               "REQUEST_METHOD":  "GET",
+               "HTTP_HOST":       "example.com",
+               "HTTP_REFERER":    "elsewhere",
+               "HTTP_USER_AGENT": "goclient",
+               "HTTP_FOO_BAR":    "baz",
+               "REQUEST_URI":     "/path?a=b",
+               "CONTENT_LENGTH":  "123",
+       }
+       req, err := requestFromEnvironment(env)
+       if err != nil {
+               t.Fatalf("requestFromEnvironment: %v", err)
+       }
+       if g, e := req.UserAgent, "goclient"; e != g {
+               t.Errorf("expected UserAgent %q; got %q", e, g)
+       }
+       if g, e := req.Method, "GET"; e != g {
+               t.Errorf("expected Method %q; got %q", e, g)
+       }
+       if g, e := req.Header.Get("User-Agent"), ""; e != g {
+               // Tests that we don't put recognized headers in the map
+               t.Errorf("expected User-Agent %q; got %q", e, g)
+       }
+       if g, e := req.ContentLength, int64(123); e != g {
+               t.Errorf("expected ContentLength %d; got %d", e, g)
+       }
+       if g, e := req.Referer, "elsewhere"; e != g {
+               t.Errorf("expected Referer %q; got %q", e, g)
+       }
+       if req.Header == nil {
+               t.Fatalf("unexpected nil Header")
+       }
+       if g, e := req.Header.Get("Foo-Bar"), "baz"; e != g {
+               t.Errorf("expected Foo-Bar %q; got %q", e, g)
+       }
+       if g, e := req.RawURL, "http://example.com/path?a=b"; e != g {
+               t.Errorf("expected RawURL %q; got %q", e, g)
+       }
+       if g, e := req.URL.String(), "http://example.com/path?a=b"; e != g {
+               t.Errorf("expected URL %q; got %q", e, g)
+       }
+       if g, e := req.FormValue("a"), "b"; e != g {
+               t.Errorf("expected FormValue(a) %q; got %q", e, g)
+       }
+       if req.Trailer == nil {
+               t.Errorf("unexpected nil Trailer")
+       }
+}
+
+func TestRequestWithoutHost(t *testing.T) {
+       env := map[string]string{
+               "HTTP_HOST":      "",
+               "REQUEST_METHOD": "GET",
+               "REQUEST_URI":    "/path?a=b",
+               "CONTENT_LENGTH": "123",
+       }
+       req, err := requestFromEnvironment(env)
+       if err != nil {
+               t.Fatalf("requestFromEnvironment: %v", err)
+       }
+       if g, e := req.RawURL, "/path?a=b"; e != g {
+               t.Errorf("expected RawURL %q; got %q", e, g)
+       }
+       if req.URL == nil {
+               t.Fatalf("unexpected nil URL")
+       }
+       if g, e := req.URL.String(), "/path?a=b"; e != g {
+               t.Errorf("expected URL %q; got %q", e, g)
+       }
+}
diff --git a/libgo/go/http/cgi/host.go b/libgo/go/http/cgi/host.go
new file mode 100644 (file)
index 0000000..2272387
--- /dev/null
@@ -0,0 +1,221 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements the host side of CGI (being the webserver
+// parent process).
+
+// Package cgi implements CGI (Common Gateway Interface) as specified
+// in RFC 3875.
+//
+// Note that using CGI means starting a new process to handle each
+// request, which is typically less efficient than using a
+// long-running server.  This package is intended primarily for
+// compatibility with existing systems.
+package cgi
+
+import (
+       "bytes"
+       "encoding/line"
+       "exec"
+       "fmt"
+       "http"
+       "io"
+       "log"
+       "os"
+       "path/filepath"
+       "regexp"
+       "strconv"
+       "strings"
+)
+
+var trailingPort = regexp.MustCompile(`:([0-9]+)$`)
+
+// Handler runs an executable in a subprocess with a CGI environment.
+type Handler struct {
+       Path string // path to the CGI executable
+       Root string // root URI prefix of handler or empty for "/"
+
+       Env    []string    // extra environment variables to set, if any
+       Logger *log.Logger // optional log for errors or nil to use log.Print
+       Args   []string    // optional arguments to pass to child process
+}
+
+func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+       root := h.Root
+       if root == "" {
+               root = "/"
+       }
+
+       if len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" {
+               rw.WriteHeader(http.StatusBadRequest)
+               rw.Write([]byte("Chunked request bodies are not supported by CGI."))
+               return
+       }
+
+       pathInfo := req.URL.Path
+       if root != "/" && strings.HasPrefix(pathInfo, root) {
+               pathInfo = pathInfo[len(root):]
+       }
+
+       port := "80"
+       if matches := trailingPort.FindStringSubmatch(req.Host); len(matches) != 0 {
+               port = matches[1]
+       }
+
+       env := []string{
+               "SERVER_SOFTWARE=go",
+               "SERVER_NAME=" + req.Host,
+               "HTTP_HOST=" + req.Host,
+               "GATEWAY_INTERFACE=CGI/1.1",
+               "REQUEST_METHOD=" + req.Method,
+               "QUERY_STRING=" + req.URL.RawQuery,
+               "REQUEST_URI=" + req.URL.RawPath,
+               "PATH_INFO=" + pathInfo,
+               "SCRIPT_NAME=" + root,
+               "SCRIPT_FILENAME=" + h.Path,
+               "REMOTE_ADDR=" + req.RemoteAddr,
+               "REMOTE_HOST=" + req.RemoteAddr,
+               "SERVER_PORT=" + port,
+       }
+
+       if req.TLS != nil {
+               env = append(env, "HTTPS=on")
+       }
+
+       if len(req.Cookie) > 0 {
+               b := new(bytes.Buffer)
+               for idx, c := range req.Cookie {
+                       if idx > 0 {
+                               b.Write([]byte("; "))
+                       }
+                       fmt.Fprintf(b, "%s=%s", c.Name, c.Value)
+               }
+               env = append(env, "HTTP_COOKIE="+b.String())
+       }
+
+       for k, v := range req.Header {
+               k = strings.Map(upperCaseAndUnderscore, k)
+               env = append(env, "HTTP_"+k+"="+strings.Join(v, ", "))
+       }
+
+       if req.ContentLength > 0 {
+               env = append(env, fmt.Sprintf("CONTENT_LENGTH=%d", req.ContentLength))
+       }
+       if ctype := req.Header.Get("Content-Type"); ctype != "" {
+               env = append(env, "CONTENT_TYPE="+ctype)
+       }
+
+       if h.Env != nil {
+               env = append(env, h.Env...)
+       }
+
+       cwd, pathBase := filepath.Split(h.Path)
+       if cwd == "" {
+               cwd = "."
+       }
+
+       args := []string{h.Path}
+       args = append(args, h.Args...)
+
+       cmd, err := exec.Run(
+               pathBase,
+               args,
+               env,
+               cwd,
+               exec.Pipe,        // stdin
+               exec.Pipe,        // stdout
+               exec.PassThrough, // stderr (for now)
+       )
+       if err != nil {
+               rw.WriteHeader(http.StatusInternalServerError)
+               h.printf("CGI error: %v", err)
+               return
+       }
+       defer func() {
+               cmd.Stdin.Close()
+               cmd.Stdout.Close()
+               cmd.Wait(0) // no zombies
+       }()
+
+       if req.ContentLength != 0 {
+               go io.Copy(cmd.Stdin, req.Body)
+       }
+
+       linebody := line.NewReader(cmd.Stdout, 1024)
+       headers := rw.Header()
+       statusCode := http.StatusOK
+       for {
+               line, isPrefix, err := linebody.ReadLine()
+               if isPrefix {
+                       rw.WriteHeader(http.StatusInternalServerError)
+                       h.printf("CGI: long header line from subprocess.")
+                       return
+               }
+               if err == os.EOF {
+                       break
+               }
+               if err != nil {
+                       rw.WriteHeader(http.StatusInternalServerError)
+                       h.printf("CGI: error reading headers: %v", err)
+                       return
+               }
+               if len(line) == 0 {
+                       break
+               }
+               parts := strings.Split(string(line), ":", 2)
+               if len(parts) < 2 {
+                       h.printf("CGI: bogus header line: %s", string(line))
+                       continue
+               }
+               header, val := parts[0], parts[1]
+               header = strings.TrimSpace(header)
+               val = strings.TrimSpace(val)
+               switch {
+               case header == "Status":
+                       if len(val) < 3 {
+                               h.printf("CGI: bogus status (short): %q", val)
+                               return
+                       }
+                       code, err := strconv.Atoi(val[0:3])
+                       if err != nil {
+                               h.printf("CGI: bogus status: %q", val)
+                               h.printf("CGI: line was %q", line)
+                               return
+                       }
+                       statusCode = code
+               default:
+                       headers.Add(header, val)
+               }
+       }
+       rw.WriteHeader(statusCode)
+
+       _, err = io.Copy(rw, linebody)
+       if err != nil {
+               h.printf("CGI: copy error: %v", err)
+       }
+}
+
+func (h *Handler) printf(format string, v ...interface{}) {
+       if h.Logger != nil {
+               h.Logger.Printf(format, v...)
+       } else {
+               log.Printf(format, v...)
+       }
+}
+
+func upperCaseAndUnderscore(rune int) int {
+       switch {
+       case rune >= 'a' && rune <= 'z':
+               return rune - ('a' - 'A')
+       case rune == '-':
+               return '_'
+       case rune == '=':
+               // Maybe not part of the CGI 'spec' but would mess up
+               // the environment in any case, as Go represents the
+               // environment as a slice of "key=value" strings.
+               return '_'
+       }
+       // TODO: other transformations in spec or practice?
+       return rune
+}
diff --git a/libgo/go/http/cgi/host_test.go b/libgo/go/http/cgi/host_test.go
new file mode 100644 (file)
index 0000000..e8084b1
--- /dev/null
@@ -0,0 +1,273 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests for package cgi
+
+package cgi
+
+import (
+       "bufio"
+       "exec"
+       "fmt"
+       "http"
+       "http/httptest"
+       "os"
+       "strings"
+       "testing"
+)
+
+var cgiScriptWorks = canRun("./testdata/test.cgi")
+
+func canRun(s string) bool {
+       c, err := exec.Run(s, []string{s}, nil, ".", exec.DevNull, exec.DevNull, exec.DevNull)
+       if err != nil {
+               return false
+       }
+       w, err := c.Wait(0)
+       if err != nil {
+               return false
+       }
+       return w.Exited() && w.ExitStatus() == 0
+}
+
+func newRequest(httpreq string) *http.Request {
+       buf := bufio.NewReader(strings.NewReader(httpreq))
+       req, err := http.ReadRequest(buf)
+       if err != nil {
+               panic("cgi: bogus http request in test: " + httpreq)
+       }
+       req.RemoteAddr = "1.2.3.4"
+       return req
+}
+
+func runCgiTest(t *testing.T, h *Handler, httpreq string, expectedMap map[string]string) *httptest.ResponseRecorder {
+       rw := httptest.NewRecorder()
+       req := newRequest(httpreq)
+       h.ServeHTTP(rw, req)
+
+       // Make a map to hold the test map that the CGI returns.
+       m := make(map[string]string)
+       linesRead := 0
+readlines:
+       for {
+               line, err := rw.Body.ReadString('\n')
+               switch {
+               case err == os.EOF:
+                       break readlines
+               case err != nil:
+                       t.Fatalf("unexpected error reading from CGI: %v", err)
+               }
+               linesRead++
+               trimmedLine := strings.TrimRight(line, "\r\n")
+               split := strings.Split(trimmedLine, "=", 2)
+               if len(split) != 2 {
+                       t.Fatalf("Unexpected %d parts from invalid line number %v: %q; existing map=%v",
+                               len(split), linesRead, line, m)
+               }
+               m[split[0]] = split[1]
+       }
+
+       for key, expected := range expectedMap {
+               if got := m[key]; got != expected {
+                       t.Errorf("for key %q got %q; expected %q", key, got, expected)
+               }
+       }
+       return rw
+}
+
+func skipTest(t *testing.T) bool {
+       if !cgiScriptWorks {
+               // No Perl on Windows, needed by test.cgi
+               // TODO: make the child process be Go, not Perl.
+               t.Logf("Skipping test: test.cgi failed.")
+               return true
+       }
+       return false
+}
+
+
+func TestCGIBasicGet(t *testing.T) {
+       if skipTest(t) {
+               return
+       }
+       h := &Handler{
+               Path: "testdata/test.cgi",
+               Root: "/test.cgi",
+       }
+       expectedMap := map[string]string{
+               "test":                  "Hello CGI",
+               "param-a":               "b",
+               "param-foo":             "bar",
+               "env-GATEWAY_INTERFACE": "CGI/1.1",
+               "env-HTTP_HOST":         "example.com",
+               "env-PATH_INFO":         "",
+               "env-QUERY_STRING":      "foo=bar&a=b",
+               "env-REMOTE_ADDR":       "1.2.3.4",
+               "env-REMOTE_HOST":       "1.2.3.4",
+               "env-REQUEST_METHOD":    "GET",
+               "env-REQUEST_URI":       "/test.cgi?foo=bar&a=b",
+               "env-SCRIPT_FILENAME":   "testdata/test.cgi",
+               "env-SCRIPT_NAME":       "/test.cgi",
+               "env-SERVER_NAME":       "example.com",
+               "env-SERVER_PORT":       "80",
+               "env-SERVER_SOFTWARE":   "go",
+       }
+       replay := runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+
+       if expected, got := "text/html", replay.Header().Get("Content-Type"); got != expected {
+               t.Errorf("got a Content-Type of %q; expected %q", got, expected)
+       }
+       if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected {
+               t.Errorf("got a X-Test-Header of %q; expected %q", got, expected)
+       }
+}
+
+func TestCGIBasicGetAbsPath(t *testing.T) {
+       if skipTest(t) {
+               return
+       }
+       pwd, err := os.Getwd()
+       if err != nil {
+               t.Fatalf("getwd error: %v", err)
+       }
+       h := &Handler{
+               Path: pwd + "/testdata/test.cgi",
+               Root: "/test.cgi",
+       }
+       expectedMap := map[string]string{
+               "env-REQUEST_URI":     "/test.cgi?foo=bar&a=b",
+               "env-SCRIPT_FILENAME": pwd + "/testdata/test.cgi",
+               "env-SCRIPT_NAME":     "/test.cgi",
+       }
+       runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestPathInfo(t *testing.T) {
+       if skipTest(t) {
+               return
+       }
+       h := &Handler{
+               Path: "testdata/test.cgi",
+               Root: "/test.cgi",
+       }
+       expectedMap := map[string]string{
+               "param-a":             "b",
+               "env-PATH_INFO":       "/extrapath",
+               "env-QUERY_STRING":    "a=b",
+               "env-REQUEST_URI":     "/test.cgi/extrapath?a=b",
+               "env-SCRIPT_FILENAME": "testdata/test.cgi",
+               "env-SCRIPT_NAME":     "/test.cgi",
+       }
+       runCgiTest(t, h, "GET /test.cgi/extrapath?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestPathInfoDirRoot(t *testing.T) {
+       if skipTest(t) {
+               return
+       }
+       h := &Handler{
+               Path: "testdata/test.cgi",
+               Root: "/myscript/",
+       }
+       expectedMap := map[string]string{
+               "env-PATH_INFO":       "bar",
+               "env-QUERY_STRING":    "a=b",
+               "env-REQUEST_URI":     "/myscript/bar?a=b",
+               "env-SCRIPT_FILENAME": "testdata/test.cgi",
+               "env-SCRIPT_NAME":     "/myscript/",
+       }
+       runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestDupHeaders(t *testing.T) {
+       if skipTest(t) {
+               return
+       }
+       h := &Handler{
+               Path: "testdata/test.cgi",
+       }
+       expectedMap := map[string]string{
+               "env-REQUEST_URI":     "/myscript/bar?a=b",
+               "env-SCRIPT_FILENAME": "testdata/test.cgi",
+               "env-HTTP_COOKIE":     "nom=NOM; yum=YUM",
+               "env-HTTP_X_FOO":      "val1, val2",
+       }
+       runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\n"+
+               "Cookie: nom=NOM\n"+
+               "Cookie: yum=YUM\n"+
+               "X-Foo: val1\n"+
+               "X-Foo: val2\n"+
+               "Host: example.com\n\n",
+               expectedMap)
+}
+
+func TestPathInfoNoRoot(t *testing.T) {
+       if skipTest(t) {
+               return
+       }
+       h := &Handler{
+               Path: "testdata/test.cgi",
+               Root: "",
+       }
+       expectedMap := map[string]string{
+               "env-PATH_INFO":       "/bar",
+               "env-QUERY_STRING":    "a=b",
+               "env-REQUEST_URI":     "/bar?a=b",
+               "env-SCRIPT_FILENAME": "testdata/test.cgi",
+               "env-SCRIPT_NAME":     "/",
+       }
+       runCgiTest(t, h, "GET /bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestCGIBasicPost(t *testing.T) {
+       if skipTest(t) {
+               return
+       }
+       postReq := `POST /test.cgi?a=b HTTP/1.0
+Host: example.com
+Content-Type: application/x-www-form-urlencoded
+Content-Length: 15
+
+postfoo=postbar`
+       h := &Handler{
+               Path: "testdata/test.cgi",
+               Root: "/test.cgi",
+       }
+       expectedMap := map[string]string{
+               "test":               "Hello CGI",
+               "param-postfoo":      "postbar",
+               "env-REQUEST_METHOD": "POST",
+               "env-CONTENT_LENGTH": "15",
+               "env-REQUEST_URI":    "/test.cgi?a=b",
+       }
+       runCgiTest(t, h, postReq, expectedMap)
+}
+
+func chunk(s string) string {
+       return fmt.Sprintf("%x\r\n%s\r\n", len(s), s)
+}
+
+// The CGI spec doesn't allow chunked requests.
+func TestCGIPostChunked(t *testing.T) {
+       if skipTest(t) {
+               return
+       }
+       postReq := `POST /test.cgi?a=b HTTP/1.1
+Host: example.com
+Content-Type: application/x-www-form-urlencoded
+Transfer-Encoding: chunked
+
+` + chunk("postfoo") + chunk("=") + chunk("postbar") + chunk("")
+
+       h := &Handler{
+               Path: "testdata/test.cgi",
+               Root: "/test.cgi",
+       }
+       expectedMap := map[string]string{}
+       resp := runCgiTest(t, h, postReq, expectedMap)
+       if got, expected := resp.Code, http.StatusBadRequest; got != expected {
+               t.Fatalf("Expected %v response code from chunked request body; got %d",
+                       expected, got)
+       }
+}
diff --git a/libgo/go/http/cgi/matryoshka_test.go b/libgo/go/http/cgi/matryoshka_test.go
new file mode 100644 (file)
index 0000000..3e4a6ad
--- /dev/null
@@ -0,0 +1,74 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests a Go CGI program running under a Go CGI host process.
+// Further, the two programs are the same binary, just checking
+// their environment to figure out what mode to run in.
+
+package cgi
+
+import (
+       "fmt"
+       "http"
+       "os"
+       "testing"
+)
+
+// This test is a CGI host (testing host.go) that runs its own binary
+// as a child process testing the other half of CGI (child.go).
+func TestHostingOurselves(t *testing.T) {
+       h := &Handler{
+               Path: os.Args[0],
+               Root: "/test.go",
+               Args: []string{"-test.run=TestBeChildCGIProcess"},
+       }
+       expectedMap := map[string]string{
+               "test":                  "Hello CGI-in-CGI",
+               "param-a":               "b",
+               "param-foo":             "bar",
+               "env-GATEWAY_INTERFACE": "CGI/1.1",
+               "env-HTTP_HOST":         "example.com",
+               "env-PATH_INFO":         "",
+               "env-QUERY_STRING":      "foo=bar&a=b",
+               "env-REMOTE_ADDR":       "1.2.3.4",
+               "env-REMOTE_HOST":       "1.2.3.4",
+               "env-REQUEST_METHOD":    "GET",
+               "env-REQUEST_URI":       "/test.go?foo=bar&a=b",
+               "env-SCRIPT_FILENAME":   os.Args[0],
+               "env-SCRIPT_NAME":       "/test.go",
+               "env-SERVER_NAME":       "example.com",
+               "env-SERVER_PORT":       "80",
+               "env-SERVER_SOFTWARE":   "go",
+       }
+       replay := runCgiTest(t, h, "GET /test.go?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+
+       if expected, got := "text/html; charset=utf-8", replay.Header().Get("Content-Type"); got != expected {
+               t.Errorf("got a Content-Type of %q; expected %q", got, expected)
+       }
+       if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected {
+               t.Errorf("got a X-Test-Header of %q; expected %q", got, expected)
+       }
+}
+
+// Note: not actually a test.
+func TestBeChildCGIProcess(t *testing.T) {
+       if os.Getenv("REQUEST_METHOD") == "" {
+               // Not in a CGI environment; skipping test.
+               return
+       }
+       Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+               rw.Header().Set("X-Test-Header", "X-Test-Value")
+               fmt.Fprintf(rw, "test=Hello CGI-in-CGI\n")
+               req.ParseForm()
+               for k, vv := range req.Form {
+                       for _, v := range vv {
+                               fmt.Fprintf(rw, "param-%s=%s\n", k, v)
+                       }
+               }
+               for _, kv := range os.Environ() {
+                       fmt.Fprintf(rw, "env-%s\n", kv)
+               }
+       }))
+       os.Exit(0)
+}
index b1fe5ec67803454a823dd37cb374e1f2a4d86b00..daba3a89b0c21effce0c6f9e4f85c3d97afb6968 100644 (file)
@@ -11,6 +11,7 @@ import (
        "encoding/base64"
        "fmt"
        "io"
+       "io/ioutil"
        "os"
        "strconv"
        "strings"
@@ -20,26 +21,28 @@ import (
 // that uses DefaultTransport.
 // Client is not yet very configurable.
 type Client struct {
-       Transport ClientTransport // if nil, DefaultTransport is used
+       Transport RoundTripper // if nil, DefaultTransport is used
 }
 
 // DefaultClient is the default Client and is used by Get, Head, and Post.
 var DefaultClient = &Client{}
 
-// ClientTransport is an interface representing the ability to execute a
+// RoundTripper is an interface representing the ability to execute a
 // single HTTP transaction, obtaining the Response for a given Request.
-type ClientTransport interface {
-       // Do executes a single HTTP transaction, returning the Response for the
-       // request req.  Do should not attempt to interpret the response.
-       // In particular, Do must return err == nil if it obtained a response,
-       // regardless of the response's HTTP status code.  A non-nil err should
-       // be reserved for failure to obtain a response.  Similarly, Do should
-       // not attempt to handle higher-level protocol details such as redirects,
+type RoundTripper interface {
+       // RoundTrip executes a single HTTP transaction, returning
+       // the Response for the request req.  RoundTrip should not
+       // attempt to interpret the response.  In particular,
+       // RoundTrip must return err == nil if it obtained a response,
+       // regardless of the response's HTTP status code.  A non-nil
+       // err should be reserved for failure to obtain a response.
+       // Similarly, RoundTrip should not attempt to handle
+       // higher-level protocol details such as redirects,
        // authentication, or cookies.
        //
-       // Transports may modify the request. The request Headers field is
-       // guaranteed to be initalized.
-       Do(req *Request) (resp *Response, err os.Error)
+       // RoundTrip may modify the request. The request Headers field is
+       // guaranteed to be initialized.
+       RoundTrip(req *Request) (resp *Response, err os.Error)
 }
 
 // Given a string of the form "host", "host:port", or "[ipv6::address]:port",
@@ -54,40 +57,6 @@ type readClose struct {
        io.Closer
 }
 
-// matchNoProxy returns true if requests to addr should not use a proxy,
-// according to the NO_PROXY or no_proxy environment variable.
-func matchNoProxy(addr string) bool {
-       if len(addr) == 0 {
-               return false
-       }
-       no_proxy := os.Getenv("NO_PROXY")
-       if len(no_proxy) == 0 {
-               no_proxy = os.Getenv("no_proxy")
-       }
-       if no_proxy == "*" {
-               return true
-       }
-
-       addr = strings.ToLower(strings.TrimSpace(addr))
-       if hasPort(addr) {
-               addr = addr[:strings.LastIndex(addr, ":")]
-       }
-
-       for _, p := range strings.Split(no_proxy, ",", -1) {
-               p = strings.ToLower(strings.TrimSpace(p))
-               if len(p) == 0 {
-                       continue
-               }
-               if hasPort(p) {
-                       p = p[:strings.LastIndex(p, ":")]
-               }
-               if addr == p || (p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:])) {
-                       return true
-               }
-       }
-       return false
-}
-
 // Do sends an HTTP request and returns an HTTP response, following
 // policy (e.g. redirects, cookies, auth) as configured on the client.
 //
@@ -100,11 +69,7 @@ func (c *Client) Do(req *Request) (resp *Response, err os.Error) {
 
 
 // send issues an HTTP request.  Caller should close resp.Body when done reading from it.
-//
-// TODO: support persistent connections (multiple requests on a single connection).
-// send() method is nonpublic because, when we refactor the code for persistent
-// connections, it may no longer make sense to have a method with this signature.
-func send(req *Request, t ClientTransport) (resp *Response, err os.Error) {
+func send(req *Request, t RoundTripper) (resp *Response, err os.Error) {
        if t == nil {
                t = DefaultTransport
                if t == nil {
@@ -115,9 +80,9 @@ func send(req *Request, t ClientTransport) (resp *Response, err os.Error) {
 
        // Most the callers of send (Get, Post, et al) don't need
        // Headers, leaving it uninitialized.  We guarantee to the
-       // ClientTransport that this has been initialized, though.
+       // Transport that this has been initialized, though.
        if req.Header == nil {
-               req.Header = Header(make(map[string][]string))
+               req.Header = make(Header)
        }
 
        info := req.URL.RawUserinfo
@@ -130,7 +95,7 @@ func send(req *Request, t ClientTransport) (resp *Response, err os.Error) {
                }
                req.Header.Set("Authorization", "Basic "+string(encoded))
        }
-       return t.Do(req)
+       return t.RoundTrip(req)
 }
 
 // True if the specified HTTP status code is one for which the Get utility should
@@ -237,7 +202,7 @@ func (c *Client) Post(url string, bodyType string, body io.Reader) (r *Response,
        req.ProtoMajor = 1
        req.ProtoMinor = 1
        req.Close = true
-       req.Body = nopCloser{body}
+       req.Body = ioutil.NopCloser(body)
        req.Header = Header{
                "Content-Type": {bodyType},
        }
@@ -272,7 +237,7 @@ func (c *Client) PostForm(url string, data map[string]string) (r *Response, err
        req.ProtoMinor = 1
        req.Close = true
        body := urlencode(data)
-       req.Body = nopCloser{body}
+       req.Body = ioutil.NopCloser(body)
        req.Header = Header{
                "Content-Type":   {"application/x-www-form-urlencoded"},
                "Content-Length": {strconv.Itoa(body.Len())},
@@ -312,9 +277,3 @@ func (c *Client) Head(url string) (r *Response, err os.Error) {
        }
        return send(&req, c.Transport)
 }
-
-type nopCloser struct {
-       io.Reader
-}
-
-func (nopCloser) Close() os.Error { return nil }
index c89ecbce2d0d1a1a8fe422e775857adaebfa3bc9..3a6f834253b3fef59b693a8a176df63ddb939ede 100644 (file)
@@ -4,20 +4,28 @@
 
 // Tests for client.go
 
-package http
+package http_test
 
 import (
+       "fmt"
+       . "http"
+       "http/httptest"
        "io/ioutil"
        "os"
        "strings"
        "testing"
 )
 
+var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
+       w.Header().Set("Last-Modified", "sometime")
+       fmt.Fprintf(w, "User-agent: go\nDisallow: /something/")
+})
+
 func TestClient(t *testing.T) {
-       // TODO: add a proper test suite.  Current test merely verifies that
-       // we can retrieve the Google robots.txt file.
+       ts := httptest.NewServer(robotsTxtHandler)
+       defer ts.Close()
 
-       r, _, err := Get("http://www.google.com/robots.txt")
+       r, _, err := Get(ts.URL)
        var b []byte
        if err == nil {
                b, err = ioutil.ReadAll(r.Body)
@@ -31,7 +39,10 @@ func TestClient(t *testing.T) {
 }
 
 func TestClientHead(t *testing.T) {
-       r, err := Head("http://www.google.com/robots.txt")
+       ts := httptest.NewServer(robotsTxtHandler)
+       defer ts.Close()
+
+       r, err := Head(ts.URL)
        if err != nil {
                t.Fatal(err)
        }
@@ -44,7 +55,7 @@ type recordingTransport struct {
        req *Request
 }
 
-func (t *recordingTransport) Do(req *Request) (resp *Response, err os.Error) {
+func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err os.Error) {
        t.req = req
        return nil, os.NewError("dummy impl")
 }
diff --git a/libgo/go/http/cookie.go b/libgo/go/http/cookie.go
new file mode 100644 (file)
index 0000000..2bb66e5
--- /dev/null
@@ -0,0 +1,272 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+       "bytes"
+       "fmt"
+       "io"
+       "os"
+       "sort"
+       "strconv"
+       "strings"
+       "time"
+)
+
+// This implementation is done according to IETF draft-ietf-httpstate-cookie-23, found at
+//
+//    http://tools.ietf.org/html/draft-ietf-httpstate-cookie-23
+
+// A Cookie represents an HTTP cookie as sent in the Set-Cookie header of an
+// HTTP response or the Cookie header of an HTTP request.
+type Cookie struct {
+       Name       string
+       Value      string
+       Path       string
+       Domain     string
+       Expires    time.Time
+       RawExpires string
+
+       // MaxAge=0 means no 'Max-Age' attribute specified. 
+       // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'
+       // MaxAge>0 means Max-Age attribute present and given in seconds
+       MaxAge   int
+       Secure   bool
+       HttpOnly bool
+       Raw      string
+       Unparsed []string // Raw text of unparsed attribute-value pairs
+}
+
+// readSetCookies parses all "Set-Cookie" values from
+// the header h, removes the successfully parsed values from the 
+// "Set-Cookie" key in h and returns the parsed Cookies.
+func readSetCookies(h Header) []*Cookie {
+       cookies := []*Cookie{}
+       var unparsedLines []string
+       for _, line := range h["Set-Cookie"] {
+               parts := strings.Split(strings.TrimSpace(line), ";", -1)
+               if len(parts) == 1 && parts[0] == "" {
+                       continue
+               }
+               parts[0] = strings.TrimSpace(parts[0])
+               j := strings.Index(parts[0], "=")
+               if j < 0 {
+                       unparsedLines = append(unparsedLines, line)
+                       continue
+               }
+               name, value := parts[0][:j], parts[0][j+1:]
+               if !isCookieNameValid(name) {
+                       unparsedLines = append(unparsedLines, line)
+                       continue
+               }
+               value, success := parseCookieValue(value)
+               if !success {
+                       unparsedLines = append(unparsedLines, line)
+                       continue
+               }
+               c := &Cookie{
+                       Name:  name,
+                       Value: value,
+                       Raw:   line,
+               }
+               for i := 1; i < len(parts); i++ {
+                       parts[i] = strings.TrimSpace(parts[i])
+                       if len(parts[i]) == 0 {
+                               continue
+                       }
+
+                       attr, val := parts[i], ""
+                       if j := strings.Index(attr, "="); j >= 0 {
+                               attr, val = attr[:j], attr[j+1:]
+                       }
+                       val, success = parseCookieValue(val)
+                       if !success {
+                               c.Unparsed = append(c.Unparsed, parts[i])
+                               continue
+                       }
+                       switch strings.ToLower(attr) {
+                       case "secure":
+                               c.Secure = true
+                               continue
+                       case "httponly":
+                               c.HttpOnly = true
+                               continue
+                       case "domain":
+                               c.Domain = val
+                               // TODO: Add domain parsing
+                               continue
+                       case "max-age":
+                               secs, err := strconv.Atoi(val)
+                               if err != nil || secs < 0 || secs != 0 && val[0] == '0' {
+                                       break
+                               }
+                               if secs <= 0 {
+                                       c.MaxAge = -1
+                               } else {
+                                       c.MaxAge = secs
+                               }
+                               continue
+                       case "expires":
+                               c.RawExpires = val
+                               exptime, err := time.Parse(time.RFC1123, val)
+                               if err != nil {
+                                       c.Expires = time.Time{}
+                                       break
+                               }
+                               c.Expires = *exptime
+                               continue
+                       case "path":
+                               c.Path = val
+                               // TODO: Add path parsing
+                               continue
+                       }
+                       c.Unparsed = append(c.Unparsed, parts[i])
+               }
+               cookies = append(cookies, c)
+       }
+       h["Set-Cookie"] = unparsedLines, unparsedLines != nil
+       return cookies
+}
+
+// writeSetCookies writes the wire representation of the set-cookies
+// to w. Each cookie is written on a separate "Set-Cookie: " line.
+// This choice is made because HTTP parsers tend to have a limit on
+// line-length, so it seems safer to place cookies on separate lines.
+func writeSetCookies(w io.Writer, kk []*Cookie) os.Error {
+       if kk == nil {
+               return nil
+       }
+       lines := make([]string, 0, len(kk))
+       var b bytes.Buffer
+       for _, c := range kk {
+               b.Reset()
+               fmt.Fprintf(&b, "%s=%s", c.Name, c.Value)
+               if len(c.Path) > 0 {
+                       fmt.Fprintf(&b, "; Path=%s", URLEscape(c.Path))
+               }
+               if len(c.Domain) > 0 {
+                       fmt.Fprintf(&b, "; Domain=%s", URLEscape(c.Domain))
+               }
+               if len(c.Expires.Zone) > 0 {
+                       fmt.Fprintf(&b, "; Expires=%s", c.Expires.Format(time.RFC1123))
+               }
+               if c.MaxAge > 0 {
+                       fmt.Fprintf(&b, "; Max-Age=%d", c.MaxAge)
+               } else if c.MaxAge < 0 {
+                       fmt.Fprintf(&b, "; Max-Age=0")
+               }
+               if c.HttpOnly {
+                       fmt.Fprintf(&b, "; HttpOnly")
+               }
+               if c.Secure {
+                       fmt.Fprintf(&b, "; Secure")
+               }
+               lines = append(lines, "Set-Cookie: "+b.String()+"\r\n")
+       }
+       sort.SortStrings(lines)
+       for _, l := range lines {
+               if _, err := io.WriteString(w, l); err != nil {
+                       return err
+               }
+       }
+       return nil
+}
+
+// readCookies parses all "Cookie" values from
+// the header h, removes the successfully parsed values from the 
+// "Cookie" key in h and returns the parsed Cookies.
+func readCookies(h Header) []*Cookie {
+       cookies := []*Cookie{}
+       lines, ok := h["Cookie"]
+       if !ok {
+               return cookies
+       }
+       unparsedLines := []string{}
+       for _, line := range lines {
+               parts := strings.Split(strings.TrimSpace(line), ";", -1)
+               if len(parts) == 1 && parts[0] == "" {
+                       continue
+               }
+               // Per-line attributes
+               parsedPairs := 0
+               for i := 0; i < len(parts); i++ {
+                       parts[i] = strings.TrimSpace(parts[i])
+                       if len(parts[i]) == 0 {
+                               continue
+                       }
+                       attr, val := parts[i], ""
+                       if j := strings.Index(attr, "="); j >= 0 {
+                               attr, val = attr[:j], attr[j+1:]
+                       }
+                       if !isCookieNameValid(attr) {
+                               continue
+                       }
+                       val, success := parseCookieValue(val)
+                       if !success {
+                               continue
+                       }
+                       cookies = append(cookies, &Cookie{Name: attr, Value: val})
+                       parsedPairs++
+               }
+               if parsedPairs == 0 {
+                       unparsedLines = append(unparsedLines, line)
+               }
+       }
+       h["Cookie"] = unparsedLines, len(unparsedLines) > 0
+       return cookies
+}
+
+// writeCookies writes the wire representation of the cookies
+// to w. Each cookie is written on a separate "Cookie: " line.
+// This choice is made because HTTP parsers tend to have a limit on
+// line-length, so it seems safer to place cookies on separate lines.
+func writeCookies(w io.Writer, kk []*Cookie) os.Error {
+       lines := make([]string, 0, len(kk))
+       for _, c := range kk {
+               lines = append(lines, fmt.Sprintf("Cookie: %s=%s\r\n", c.Name, c.Value))
+       }
+       sort.SortStrings(lines)
+       for _, l := range lines {
+               if _, err := io.WriteString(w, l); err != nil {
+                       return err
+               }
+       }
+       return nil
+}
+
+func unquoteCookieValue(v string) string {
+       if len(v) > 1 && v[0] == '"' && v[len(v)-1] == '"' {
+               return v[1 : len(v)-1]
+       }
+       return v
+}
+
+func isCookieByte(c byte) bool {
+       switch true {
+       case c == 0x21, 0x23 <= c && c <= 0x2b, 0x2d <= c && c <= 0x3a,
+               0x3c <= c && c <= 0x5b, 0x5d <= c && c <= 0x7e:
+               return true
+       }
+       return false
+}
+
+func parseCookieValue(raw string) (string, bool) {
+       raw = unquoteCookieValue(raw)
+       for i := 0; i < len(raw); i++ {
+               if !isCookieByte(raw[i]) {
+                       return "", false
+               }
+       }
+       return raw, true
+}
+
+func isCookieNameValid(raw string) bool {
+       for _, c := range raw {
+               if !isToken(byte(c)) {
+                       return false
+               }
+       }
+       return true
+}
diff --git a/libgo/go/http/cookie_test.go b/libgo/go/http/cookie_test.go
new file mode 100644 (file)
index 0000000..db09970
--- /dev/null
@@ -0,0 +1,110 @@
+// Copyright 2010 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+       "bytes"
+       "fmt"
+       "json"
+       "reflect"
+       "testing"
+)
+
+
+var writeSetCookiesTests = []struct {
+       Cookies []*Cookie
+       Raw     string
+}{
+       {
+               []*Cookie{
+                       &Cookie{Name: "cookie-1", Value: "v$1"},
+                       &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600},
+               },
+               "Set-Cookie: cookie-1=v$1\r\n" +
+                       "Set-Cookie: cookie-2=two; Max-Age=3600\r\n",
+       },
+}
+
+func TestWriteSetCookies(t *testing.T) {
+       for i, tt := range writeSetCookiesTests {
+               var w bytes.Buffer
+               writeSetCookies(&w, tt.Cookies)
+               seen := string(w.Bytes())
+               if seen != tt.Raw {
+                       t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, tt.Raw, seen)
+                       continue
+               }
+       }
+}
+
+var writeCookiesTests = []struct {
+       Cookies []*Cookie
+       Raw     string
+}{
+       {
+               []*Cookie{&Cookie{Name: "cookie-1", Value: "v$1"}},
+               "Cookie: cookie-1=v$1\r\n",
+       },
+}
+
+func TestWriteCookies(t *testing.T) {
+       for i, tt := range writeCookiesTests {
+               var w bytes.Buffer
+               writeCookies(&w, tt.Cookies)
+               seen := string(w.Bytes())
+               if seen != tt.Raw {
+                       t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, tt.Raw, seen)
+                       continue
+               }
+       }
+}
+
+var readSetCookiesTests = []struct {
+       Header  Header
+       Cookies []*Cookie
+}{
+       {
+               Header{"Set-Cookie": {"Cookie-1=v$1"}},
+               []*Cookie{&Cookie{Name: "Cookie-1", Value: "v$1", Raw: "Cookie-1=v$1"}},
+       },
+}
+
+func toJSON(v interface{}) string {
+       b, err := json.Marshal(v)
+       if err != nil {
+               return fmt.Sprintf("%#v", v)
+       }
+       return string(b)
+}
+
+func TestReadSetCookies(t *testing.T) {
+       for i, tt := range readSetCookiesTests {
+               c := readSetCookies(tt.Header)
+               if !reflect.DeepEqual(c, tt.Cookies) {
+                       t.Errorf("#%d readSetCookies: have\n%s\nwant\n%s\n", i, toJSON(c), toJSON(tt.Cookies))
+                       continue
+               }
+       }
+}
+
+var readCookiesTests = []struct {
+       Header  Header
+       Cookies []*Cookie
+}{
+       {
+               Header{"Cookie": {"Cookie-1=v$1"}},
+               []*Cookie{&Cookie{Name: "Cookie-1", Value: "v$1"}},
+       },
+}
+
+func TestReadCookies(t *testing.T) {
+       for i, tt := range readCookiesTests {
+               c := readCookies(tt.Header)
+               if !reflect.DeepEqual(c, tt.Cookies) {
+                       t.Errorf("#%d readCookies: have\n%s\nwant\n%s\n", i, toJSON(c), toJSON(tt.Cookies))
+                       continue
+               }
+       }
+}
index 73ac97973999d4a6476665077b07c8faeffbbe8a..306c45bc2c9398c9281aadc9eb8aec0d94fd56ad 100644 (file)
@@ -7,10 +7,10 @@ package http
 import (
        "bytes"
        "io"
+       "io/ioutil"
        "os"
 )
 
-
 // One of the copies, say from b to r2, could be avoided by using a more
 // elaborate trick where the other copy is made during Request/Response.Write.
 // This would complicate things too much, given that these functions are for
@@ -23,7 +23,7 @@ func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err os.Error) {
        if err = b.Close(); err != nil {
                return nil, nil, err
        }
-       return nopCloser{&buf}, nopCloser{bytes.NewBuffer(buf.Bytes())}, nil
+       return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewBuffer(buf.Bytes())), nil
 }
 
 // DumpRequest returns the wire representation of req,
diff --git a/libgo/go/http/export_test.go b/libgo/go/http/export_test.go
new file mode 100644 (file)
index 0000000..a76b707
--- /dev/null
@@ -0,0 +1,34 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Bridge package to expose http internals to tests in the http_test
+// package.
+
+package http
+
+func (t *Transport) IdleConnKeysForTesting() (keys []string) {
+       keys = make([]string, 0)
+       t.lk.Lock()
+       defer t.lk.Unlock()
+       if t.idleConn == nil {
+               return
+       }
+       for key, _ := range t.idleConn {
+               keys = append(keys, key)
+       }
+       return
+}
+
+func (t *Transport) IdleConnCountForTesting(cacheKey string) int {
+       t.lk.Lock()
+       defer t.lk.Unlock()
+       if t.idleConn == nil {
+               return 0
+       }
+       conns, ok := t.idleConn[cacheKey]
+       if !ok {
+               return 0
+       }
+       return len(conns)
+}
index 8e16992e0f088839f5070db37485a6b898682aa5..4ad680ccc314bdbd9948ca396f25deace3d4c3bc 100644 (file)
@@ -11,7 +11,7 @@ import (
        "io"
        "mime"
        "os"
-       "path"
+       "path/filepath"
        "strconv"
        "strings"
        "time"
@@ -108,11 +108,11 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) {
                w.WriteHeader(StatusNotModified)
                return
        }
-       w.SetHeader("Last-Modified", time.SecondsToUTC(d.Mtime_ns/1e9).Format(TimeFormat))
+       w.Header().Set("Last-Modified", time.SecondsToUTC(d.Mtime_ns/1e9).Format(TimeFormat))
 
        // use contents of index.html for directory, if present
        if d.IsDirectory() {
-               index := name + indexPage
+               index := name + filepath.FromSlash(indexPage)
                ff, err := os.Open(index, os.O_RDONLY, 0)
                if err == nil {
                        defer ff.Close()
@@ -135,18 +135,18 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) {
        code := StatusOK
 
        // use extension to find content type.
-       ext := path.Ext(name)
+       ext := filepath.Ext(name)
        if ctype := mime.TypeByExtension(ext); ctype != "" {
-               w.SetHeader("Content-Type", ctype)
+               w.Header().Set("Content-Type", ctype)
        } else {
                // read first chunk to decide between utf-8 text and binary
                var buf [1024]byte
                n, _ := io.ReadFull(f, buf[:])
                b := buf[:n]
                if isText(b) {
-                       w.SetHeader("Content-Type", "text-plain; charset=utf-8")
+                       w.Header().Set("Content-Type", "text-plain; charset=utf-8")
                } else {
-                       w.SetHeader("Content-Type", "application/octet-stream") // generic binary
+                       w.Header().Set("Content-Type", "application/octet-stream") // generic binary
                }
                f.Seek(0, 0) // rewind to output whole file
        }
@@ -166,11 +166,11 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) {
                }
                size = ra.length
                code = StatusPartialContent
-               w.SetHeader("Content-Range", fmt.Sprintf("bytes %d-%d/%d", ra.start, ra.start+ra.length-1, d.Size))
+               w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", ra.start, ra.start+ra.length-1, d.Size))
        }
 
-       w.SetHeader("Accept-Ranges", "bytes")
-       w.SetHeader("Content-Length", strconv.Itoa64(size))
+       w.Header().Set("Accept-Ranges", "bytes")
+       w.Header().Set("Content-Length", strconv.Itoa64(size))
 
        w.WriteHeader(code)
 
@@ -202,7 +202,7 @@ func (f *fileHandler) ServeHTTP(w ResponseWriter, r *Request) {
                return
        }
        path = path[len(f.prefix):]
-       serveFile(w, r, f.root+"/"+path, true)
+       serveFile(w, r, filepath.Join(f.root, filepath.FromSlash(path)), true)
 }
 
 // httpRange specifies the byte range to be sent to the client.
index a8b67e3f08c0e2fd3a73683a167bfc7aae2ce390..a89c76d0bfb5c90b267f7c4e25bdb459b4e47d0d 100644 (file)
@@ -2,89 +2,22 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-package http
+package http_test
 
 import (
        "fmt"
+       . "http"
+       "http/httptest"
        "io/ioutil"
-       "net"
        "os"
-       "sync"
        "testing"
 )
 
-var ParseRangeTests = []struct {
-       s      string
-       length int64
-       r      []httpRange
-}{
-       {"", 0, nil},
-       {"foo", 0, nil},
-       {"bytes=", 0, nil},
-       {"bytes=5-4", 10, nil},
-       {"bytes=0-2,5-4", 10, nil},
-       {"bytes=0-9", 10, []httpRange{{0, 10}}},
-       {"bytes=0-", 10, []httpRange{{0, 10}}},
-       {"bytes=5-", 10, []httpRange{{5, 5}}},
-       {"bytes=0-20", 10, []httpRange{{0, 10}}},
-       {"bytes=15-,0-5", 10, nil},
-       {"bytes=-5", 10, []httpRange{{5, 5}}},
-       {"bytes=-15", 10, []httpRange{{0, 10}}},
-       {"bytes=0-499", 10000, []httpRange{{0, 500}}},
-       {"bytes=500-999", 10000, []httpRange{{500, 500}}},
-       {"bytes=-500", 10000, []httpRange{{9500, 500}}},
-       {"bytes=9500-", 10000, []httpRange{{9500, 500}}},
-       {"bytes=0-0,-1", 10000, []httpRange{{0, 1}, {9999, 1}}},
-       {"bytes=500-600,601-999", 10000, []httpRange{{500, 101}, {601, 399}}},
-       {"bytes=500-700,601-999", 10000, []httpRange{{500, 201}, {601, 399}}},
-}
-
-func TestParseRange(t *testing.T) {
-       for _, test := range ParseRangeTests {
-               r := test.r
-               ranges, err := parseRange(test.s, test.length)
-               if err != nil && r != nil {
-                       t.Errorf("parseRange(%q) returned error %q", test.s, err)
-               }
-               if len(ranges) != len(r) {
-                       t.Errorf("len(parseRange(%q)) = %d, want %d", test.s, len(ranges), len(r))
-                       continue
-               }
-               for i := range r {
-                       if ranges[i].start != r[i].start {
-                               t.Errorf("parseRange(%q)[%d].start = %d, want %d", test.s, i, ranges[i].start, r[i].start)
-                       }
-                       if ranges[i].length != r[i].length {
-                               t.Errorf("parseRange(%q)[%d].length = %d, want %d", test.s, i, ranges[i].length, r[i].length)
-                       }
-               }
-       }
-}
-
 const (
        testFile       = "testdata/file"
        testFileLength = 11
 )
 
-var (
-       serverOnce sync.Once
-       serverAddr string
-)
-
-func startServer(t *testing.T) {
-       serverOnce.Do(func() {
-               HandleFunc("/ServeFile", func(w ResponseWriter, r *Request) {
-                       ServeFile(w, r, "testdata/file")
-               })
-               l, err := net.Listen("tcp", "127.0.0.1:0")
-               if err != nil {
-                       t.Fatal("listen:", err)
-               }
-               serverAddr = l.Addr().String()
-               go Serve(l, nil)
-       })
-}
-
 var ServeFileRangeTests = []struct {
        start, end int
        r          string
@@ -99,7 +32,11 @@ var ServeFileRangeTests = []struct {
 }
 
 func TestServeFile(t *testing.T) {
-       startServer(t)
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               ServeFile(w, r, "testdata/file")
+       }))
+       defer ts.Close()
+
        var err os.Error
 
        file, err := ioutil.ReadFile(testFile)
@@ -110,7 +47,7 @@ func TestServeFile(t *testing.T) {
        // set up the Request (re-used for all tests)
        var req Request
        req.Header = make(Header)
-       if req.URL, err = ParseURL("http://" + serverAddr + "/ServeFile"); err != nil {
+       if req.URL, err = ParseURL(ts.URL); err != nil {
                t.Fatal("ParseURL:", err)
        }
        req.Method = "GET"
@@ -149,7 +86,7 @@ func TestServeFile(t *testing.T) {
 }
 
 func getBody(t *testing.T, req Request) (*Response, []byte) {
-       r, err := send(&req, DefaultTransport)
+       r, err := DefaultClient.Do(&req)
        if err != nil {
                t.Fatal(req.URL.String(), "send:", err)
        }
diff --git a/libgo/go/http/httptest/recorder.go b/libgo/go/http/httptest/recorder.go
new file mode 100644 (file)
index 0000000..0dd19a6
--- /dev/null
@@ -0,0 +1,59 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The httptest package provides utilities for HTTP testing.
+package httptest
+
+import (
+       "bytes"
+       "http"
+       "os"
+)
+
+// ResponseRecorder is an implementation of http.ResponseWriter that
+// records its mutations for later inspection in tests.
+type ResponseRecorder struct {
+       Code      int           // the HTTP response code from WriteHeader
+       HeaderMap http.Header   // the HTTP response headers
+       Body      *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
+       Flushed   bool
+}
+
+// NewRecorder returns an initialized ResponseRecorder.
+func NewRecorder() *ResponseRecorder {
+       return &ResponseRecorder{
+               HeaderMap: make(http.Header),
+               Body:      new(bytes.Buffer),
+       }
+}
+
+// DefaultRemoteAddr is the default remote address to return in RemoteAddr if
+// an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
+const DefaultRemoteAddr = "1.2.3.4"
+
+// Header returns the response headers.
+func (rw *ResponseRecorder) Header() http.Header {
+       return rw.HeaderMap
+}
+
+// Write always succeeds and writes to rw.Body, if not nil.
+func (rw *ResponseRecorder) Write(buf []byte) (int, os.Error) {
+       if rw.Body != nil {
+               rw.Body.Write(buf)
+       }
+       if rw.Code == 0 {
+               rw.Code = http.StatusOK
+       }
+       return len(buf), nil
+}
+
+// WriteHeader sets rw.Code.
+func (rw *ResponseRecorder) WriteHeader(code int) {
+       rw.Code = code
+}
+
+// Flush sets rw.Flushed to true.
+func (rw *ResponseRecorder) Flush() {
+       rw.Flushed = true
+}
diff --git a/libgo/go/http/httptest/server.go b/libgo/go/http/httptest/server.go
new file mode 100644 (file)
index 0000000..6e825a8
--- /dev/null
@@ -0,0 +1,70 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Implementation of Server
+
+package httptest
+
+import (
+       "fmt"
+       "http"
+       "os"
+       "net"
+)
+
+// A Server is an HTTP server listening on a system-chosen port on the
+// local loopback interface, for use in end-to-end HTTP tests.
+type Server struct {
+       URL      string // base URL of form http://ipaddr:port with no trailing slash
+       Listener net.Listener
+}
+
+// historyListener keeps track of all connections that it's ever
+// accepted.
+type historyListener struct {
+       net.Listener
+       history []net.Conn
+}
+
+func (hs *historyListener) Accept() (c net.Conn, err os.Error) {
+       c, err = hs.Listener.Accept()
+       if err == nil {
+               hs.history = append(hs.history, c)
+       }
+       return
+}
+
+// NewServer starts and returns a new Server.
+// The caller should call Close when finished, to shut it down.
+func NewServer(handler http.Handler) *Server {
+       ts := new(Server)
+       l, err := net.Listen("tcp", "127.0.0.1:0")
+       if err != nil {
+               if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
+                       panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
+               }
+       }
+       ts.Listener = &historyListener{l, make([]net.Conn, 0)}
+       ts.URL = "http://" + l.Addr().String()
+       server := &http.Server{Handler: handler}
+       go server.Serve(ts.Listener)
+       return ts
+}
+
+// Close shuts down the server.
+func (s *Server) Close() {
+       s.Listener.Close()
+}
+
+// CloseClientConnections closes any currently open HTTP connections
+// to the test Server.
+func (s *Server) CloseClientConnections() {
+       hl, ok := s.Listener.(*historyListener)
+       if !ok {
+               return
+       }
+       for _, conn := range hl.history {
+               conn.Close()
+       }
+}
index 000a4200e59a8a4d7e8730ffe2e93948585346b1..b93c5fe4855c8c5d207c2a0b9374ddb5005e68ba 100644 (file)
@@ -25,15 +25,15 @@ var (
 // i.e. requests can be read out of sync (but in the same order) while the
 // respective responses are sent.
 type ServerConn struct {
+       lk              sync.Mutex // read-write protects the following fields
        c               net.Conn
        r               *bufio.Reader
-       clsd            bool     // indicates a graceful close
        re, we          os.Error // read/write errors
        lastbody        io.ReadCloser
        nread, nwritten int
-       pipe            textproto.Pipeline
        pipereq         map[*Request]uint
-       lk              sync.Mutex // protected read/write to re,we
+
+       pipe textproto.Pipeline
 }
 
 // NewServerConn returns a new ServerConn reading and writing c.  If r is not
@@ -90,15 +90,21 @@ func (sc *ServerConn) Read() (req *Request, err os.Error) {
                defer sc.lk.Unlock()
                return nil, sc.re
        }
+       if sc.r == nil { // connection closed by user in the meantime
+               defer sc.lk.Unlock()
+               return nil, os.EBADF
+       }
+       r := sc.r
+       lastbody := sc.lastbody
+       sc.lastbody = nil
        sc.lk.Unlock()
 
        // Make sure body is fully consumed, even if user does not call body.Close
-       if sc.lastbody != nil {
+       if lastbody != nil {
                // body.Close is assumed to be idempotent and multiple calls to
                // it should return the error that its first invokation
                // returned.
-               err = sc.lastbody.Close()
-               sc.lastbody = nil
+               err = lastbody.Close()
                if err != nil {
                        sc.lk.Lock()
                        defer sc.lk.Unlock()
@@ -107,10 +113,10 @@ func (sc *ServerConn) Read() (req *Request, err os.Error) {
                }
        }
 
-       req, err = ReadRequest(sc.r)
+       req, err = ReadRequest(r)
+       sc.lk.Lock()
+       defer sc.lk.Unlock()
        if err != nil {
-               sc.lk.Lock()
-               defer sc.lk.Unlock()
                if err == io.ErrUnexpectedEOF {
                        // A close from the opposing client is treated as a
                        // graceful close, even if there was some unparse-able
@@ -119,18 +125,16 @@ func (sc *ServerConn) Read() (req *Request, err os.Error) {
                        return nil, sc.re
                } else {
                        sc.re = err
-                       return
+                       return req, err
                }
        }
        sc.lastbody = req.Body
        sc.nread++
        if req.Close {
-               sc.lk.Lock()
-               defer sc.lk.Unlock()
                sc.re = ErrPersistEOF
                return req, sc.re
        }
-       return
+       return req, err
 }
 
 // Pending returns the number of unanswered requests
@@ -165,24 +169,27 @@ func (sc *ServerConn) Write(req *Request, resp *Response) os.Error {
                defer sc.lk.Unlock()
                return sc.we
        }
-       sc.lk.Unlock()
+       if sc.c == nil { // connection closed by user in the meantime
+               defer sc.lk.Unlock()
+               return os.EBADF
+       }
+       c := sc.c
        if sc.nread <= sc.nwritten {
+               defer sc.lk.Unlock()
                return os.NewError("persist server pipe count")
        }
-
        if resp.Close {
                // After signaling a keep-alive close, any pipelined unread
                // requests will be lost. It is up to the user to drain them
                // before signaling.
-               sc.lk.Lock()
                sc.re = ErrPersistEOF
-               sc.lk.Unlock()
        }
+       sc.lk.Unlock()
 
-       err := resp.Write(sc.c)
+       err := resp.Write(c)
+       sc.lk.Lock()
+       defer sc.lk.Unlock()
        if err != nil {
-               sc.lk.Lock()
-               defer sc.lk.Unlock()
                sc.we = err
                return err
        }
@@ -196,14 +203,17 @@ func (sc *ServerConn) Write(req *Request, resp *Response) os.Error {
 // responsible for closing the underlying connection. One must call Close to
 // regain control of that connection and deal with it as desired.
 type ClientConn struct {
+       lk              sync.Mutex // read-write protects the following fields
        c               net.Conn
        r               *bufio.Reader
        re, we          os.Error // read/write errors
        lastbody        io.ReadCloser
        nread, nwritten int
-       pipe            textproto.Pipeline
        pipereq         map[*Request]uint
-       lk              sync.Mutex // protects read/write to re,we,pipereq,etc.
+
+       pipe     textproto.Pipeline
+       writeReq func(*Request, io.Writer) os.Error
+       readRes  func(buf *bufio.Reader, method string) (*Response, os.Error)
 }
 
 // NewClientConn returns a new ClientConn reading and writing c.  If r is not
@@ -212,7 +222,21 @@ func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
        if r == nil {
                r = bufio.NewReader(c)
        }
-       return &ClientConn{c: c, r: r, pipereq: make(map[*Request]uint)}
+       return &ClientConn{
+               c:        c,
+               r:        r,
+               pipereq:  make(map[*Request]uint),
+               writeReq: (*Request).Write,
+               readRes:  ReadResponse,
+       }
+}
+
+// NewProxyClientConn works like NewClientConn but writes Requests
+// using Request's WriteProxy method.
+func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
+       cc := NewClientConn(c, r)
+       cc.writeReq = (*Request).WriteProxy
+       return cc
 }
 
 // Close detaches the ClientConn and returns the underlying connection as well
@@ -221,11 +245,11 @@ func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
 // logic. The user should not call Close while Read or Write is in progress.
 func (cc *ClientConn) Close() (c net.Conn, r *bufio.Reader) {
        cc.lk.Lock()
+       defer cc.lk.Unlock()
        c = cc.c
        r = cc.r
        cc.c = nil
        cc.r = nil
-       cc.lk.Unlock()
        return
 }
 
@@ -261,20 +285,22 @@ func (cc *ClientConn) Write(req *Request) (err os.Error) {
                defer cc.lk.Unlock()
                return cc.we
        }
-       cc.lk.Unlock()
-
+       if cc.c == nil { // connection closed by user in the meantime
+               defer cc.lk.Unlock()
+               return os.EBADF
+       }
+       c := cc.c
        if req.Close {
                // We write the EOF to the write-side error, because there
                // still might be some pipelined reads
-               cc.lk.Lock()
                cc.we = ErrPersistEOF
-               cc.lk.Unlock()
        }
+       cc.lk.Unlock()
 
-       err = req.Write(cc.c)
+       err = cc.writeReq(req, c)
+       cc.lk.Lock()
+       defer cc.lk.Unlock()
        if err != nil {
-               cc.lk.Lock()
-               defer cc.lk.Unlock()
                cc.we = err
                return err
        }
@@ -316,15 +342,21 @@ func (cc *ClientConn) Read(req *Request) (resp *Response, err os.Error) {
                defer cc.lk.Unlock()
                return nil, cc.re
        }
+       if cc.r == nil { // connection closed by user in the meantime
+               defer cc.lk.Unlock()
+               return nil, os.EBADF
+       }
+       r := cc.r
+       lastbody := cc.lastbody
+       cc.lastbody = nil
        cc.lk.Unlock()
 
        // Make sure body is fully consumed, even if user does not call body.Close
-       if cc.lastbody != nil {
+       if lastbody != nil {
                // body.Close is assumed to be idempotent and multiple calls to
                // it should return the error that its first invokation
                // returned.
-               err = cc.lastbody.Close()
-               cc.lastbody = nil
+               err = lastbody.Close()
                if err != nil {
                        cc.lk.Lock()
                        defer cc.lk.Unlock()
@@ -333,24 +365,22 @@ func (cc *ClientConn) Read(req *Request) (resp *Response, err os.Error) {
                }
        }
 
-       resp, err = ReadResponse(cc.r, req.Method)
+       resp, err = cc.readRes(r, req.Method)
+       cc.lk.Lock()
+       defer cc.lk.Unlock()
        if err != nil {
-               cc.lk.Lock()
-               defer cc.lk.Unlock()
                cc.re = err
-               return
+               return resp, err
        }
        cc.lastbody = resp.Body
 
        cc.nread++
 
        if resp.Close {
-               cc.lk.Lock()
-               defer cc.lk.Unlock()
                cc.re = ErrPersistEOF // don't send any more requests
                return resp, cc.re
        }
-       return
+       return resp, err
 }
 
 // Do is convenience method that writes a request and reads a response.
index f7db9aab93bf220d916c83e27f943a63fa1ef5fc..0bac26687d73d842b99793d5b75c142d8ee4a053 100644 (file)
@@ -41,14 +41,14 @@ func init() {
 // command line, with arguments separated by NUL bytes.
 // The package initialization registers it as /debug/pprof/cmdline.
 func Cmdline(w http.ResponseWriter, r *http.Request) {
-       w.SetHeader("content-type", "text/plain; charset=utf-8")
+       w.Header().Set("content-type", "text/plain; charset=utf-8")
        fmt.Fprintf(w, strings.Join(os.Args, "\x00"))
 }
 
 // Heap responds with the pprof-formatted heap profile.
 // The package initialization registers it as /debug/pprof/heap.
 func Heap(w http.ResponseWriter, r *http.Request) {
-       w.SetHeader("content-type", "text/plain; charset=utf-8")
+       w.Header().Set("content-type", "text/plain; charset=utf-8")
        pprof.WriteHeapProfile(w)
 }
 
@@ -56,7 +56,7 @@ func Heap(w http.ResponseWriter, r *http.Request) {
 // responding with a table mapping program counters to function names.
 // The package initialization registers it as /debug/pprof/symbol.
 func Symbol(w http.ResponseWriter, r *http.Request) {
-       w.SetHeader("content-type", "text/plain; charset=utf-8")
+       w.Header().Set("content-type", "text/plain; charset=utf-8")
 
        // We don't know how many symbols we have, but we
        // do have symbol information.  Pprof only cares whether
index 0f2ca458fed2c53d03781a95f462875da9384f28..7050ef5ed063d96fc15fe957c72a2020580668a7 100644 (file)
@@ -12,31 +12,33 @@ import (
 // TODO(mattn):
 //     test ProxyAuth
 
-var MatchNoProxyTests = []struct {
+var UseProxyTests = []struct {
        host  string
        match bool
 }{
-       {"localhost", true},        // match completely
-       {"barbaz.net", true},       // match as .barbaz.net
-       {"foobar.com:443", true},   // have a port but match 
-       {"foofoobar.com", false},   // not match as a part of foobar.com
-       {"baz.com", false},         // not match as a part of barbaz.com
-       {"localhost.net", false},   // not match as suffix of address
-       {"local.localhost", false}, // not match as prefix as address
-       {"barbarbaz.net", false},   // not match because NO_PROXY have a '.'
-       {"www.foobar.com", false},  // not match because NO_PROXY is not .foobar.com
+       {"localhost", false},      // match completely
+       {"barbaz.net", false},     // match as .barbaz.net
+       {"foobar.com:443", false}, // have a port but match 
+       {"foofoobar.com", true},   // not match as a part of foobar.com
+       {"baz.com", true},         // not match as a part of barbaz.com
+       {"localhost.net", true},   // not match as suffix of address
+       {"local.localhost", true}, // not match as prefix as address
+       {"barbarbaz.net", true},   // not match because NO_PROXY have a '.'
+       {"www.foobar.com", true},  // not match because NO_PROXY is not .foobar.com
 }
 
-func TestMatchNoProxy(t *testing.T) {
+func TestUseProxy(t *testing.T) {
        oldenv := os.Getenv("NO_PROXY")
        no_proxy := "foobar.com, .barbaz.net   , localhost"
        os.Setenv("NO_PROXY", no_proxy)
        defer os.Setenv("NO_PROXY", oldenv)
 
-       for _, test := range MatchNoProxyTests {
-               if matchNoProxy(test.host) != test.match {
+       tr := &Transport{}
+
+       for _, test := range UseProxyTests {
+               if tr.useProxy(test.host) != test.match {
                        if test.match {
-                               t.Errorf("matchNoProxy(%v) = %v, want %v", test.host, !test.match, test.match)
+                               t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match)
                        } else {
                                t.Errorf("not expected: '%s' shouldn't match as '%s'", test.host, no_proxy)
                        }
diff --git a/libgo/go/http/range_test.go b/libgo/go/http/range_test.go
new file mode 100644 (file)
index 0000000..5274a81
--- /dev/null
@@ -0,0 +1,57 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+       "testing"
+)
+
+var ParseRangeTests = []struct {
+       s      string
+       length int64
+       r      []httpRange
+}{
+       {"", 0, nil},
+       {"foo", 0, nil},
+       {"bytes=", 0, nil},
+       {"bytes=5-4", 10, nil},
+       {"bytes=0-2,5-4", 10, nil},
+       {"bytes=0-9", 10, []httpRange{{0, 10}}},
+       {"bytes=0-", 10, []httpRange{{0, 10}}},
+       {"bytes=5-", 10, []httpRange{{5, 5}}},
+       {"bytes=0-20", 10, []httpRange{{0, 10}}},
+       {"bytes=15-,0-5", 10, nil},
+       {"bytes=-5", 10, []httpRange{{5, 5}}},
+       {"bytes=-15", 10, []httpRange{{0, 10}}},
+       {"bytes=0-499", 10000, []httpRange{{0, 500}}},
+       {"bytes=500-999", 10000, []httpRange{{500, 500}}},
+       {"bytes=-500", 10000, []httpRange{{9500, 500}}},
+       {"bytes=9500-", 10000, []httpRange{{9500, 500}}},
+       {"bytes=0-0,-1", 10000, []httpRange{{0, 1}, {9999, 1}}},
+       {"bytes=500-600,601-999", 10000, []httpRange{{500, 101}, {601, 399}}},
+       {"bytes=500-700,601-999", 10000, []httpRange{{500, 201}, {601, 399}}},
+}
+
+func TestParseRange(t *testing.T) {
+       for _, test := range ParseRangeTests {
+               r := test.r
+               ranges, err := parseRange(test.s, test.length)
+               if err != nil && r != nil {
+                       t.Errorf("parseRange(%q) returned error %q", test.s, err)
+               }
+               if len(ranges) != len(r) {
+                       t.Errorf("len(parseRange(%q)) = %d, want %d", test.s, len(ranges), len(r))
+                       continue
+               }
+               for i := range r {
+                       if ranges[i].start != r[i].start {
+                               t.Errorf("parseRange(%q)[%d].start = %d, want %d", test.s, i, ranges[i].start, r[i].start)
+                       }
+                       if ranges[i].length != r[i].length {
+                               t.Errorf("parseRange(%q)[%d].length = %d, want %d", test.s, i, ranges[i].length, r[i].length)
+                       }
+               }
+       }
+}
index 6ee07bc9148a2687894d036ff7423ce96123de5f..19e2ff77476a97581c0f7f151b0af4cef8a750d8 100644 (file)
@@ -93,7 +93,7 @@ var reqTests = []reqTest{
                        Proto:         "HTTP/1.1",
                        ProtoMajor:    1,
                        ProtoMinor:    1,
-                       Header:        map[string][]string{},
+                       Header:        Header{},
                        Close:         false,
                        ContentLength: -1,
                        Host:          "test",
index a7dc328a0075c75f1eb17fa58fa5c491c7170454..d82894fab08829716e6411d5dba7ee46be574eb5 100644 (file)
@@ -11,6 +11,7 @@ package http
 
 import (
        "bufio"
+       "crypto/tls"
        "container/vector"
        "fmt"
        "io"
@@ -92,6 +93,9 @@ type Request struct {
        // following a hyphen uppercase and the rest lowercase.
        Header Header
 
+       // Cookie records the HTTP cookies sent with the request.
+       Cookie []*Cookie
+
        // The message body.
        Body io.ReadCloser
 
@@ -134,6 +138,22 @@ type Request struct {
        // response has multiple trailer lines with the same key, they will be
        // concatenated, delimited by commas.
        Trailer Header
+
+       // RemoteAddr allows HTTP servers and other software to record
+       // the network address that sent the request, usually for
+       // logging. This field is not filled in by ReadRequest and
+       // has no defined format. The HTTP server in this package
+       // sets RemoteAddr to an "IP:port" address before invoking a
+       // handler.
+       RemoteAddr string
+
+       // TLS allows HTTP servers and other software to record
+       // information about the TLS connection on which the request
+       // was received. This field is not filled in by ReadRequest.
+       // The HTTP server in this package sets the field for
+       // TLS-enabled connections before invoking a handler;
+       // otherwise it leaves the field nil.
+       TLS *tls.ConnectionState
 }
 
 // ProtoAtLeast returns whether the HTTP protocol used
@@ -190,6 +210,8 @@ func (req *Request) Write(w io.Writer) os.Error {
 // WriteProxy is like Write but writes the request in the form
 // expected by an HTTP proxy.  It includes the scheme and host
 // name in the URI instead of using a separate Host: header line.
+// If req.RawURL is non-empty, WriteProxy uses it unchanged
+// instead of URL but still omits the Host: header.
 func (req *Request) WriteProxy(w io.Writer) os.Error {
        return req.write(w, true)
 }
@@ -206,13 +228,12 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error {
                if req.URL.RawQuery != "" {
                        uri += "?" + req.URL.RawQuery
                }
-       }
-
-       if usingProxy {
-               if uri == "" || uri[0] != '/' {
-                       uri = "/" + uri
+               if usingProxy {
+                       if uri == "" || uri[0] != '/' {
+                               uri = "/" + uri
+                       }
+                       uri = req.URL.Scheme + "://" + host + uri
                }
-               uri = req.URL.Scheme + "://" + host + uri
        }
 
        fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), uri)
@@ -243,11 +264,15 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error {
        // from Request, and introduce Request methods along the lines of
        // Response.{GetHeader,AddHeader} and string constants for "Host",
        // "User-Agent" and "Referer".
-       err = writeSortedKeyValue(w, req.Header, reqExcludeHeader)
+       err = writeSortedHeader(w, req.Header, reqExcludeHeader)
        if err != nil {
                return err
        }
 
+       if err = writeCookies(w, req.Cookie); err != nil {
+               return err
+       }
+
        io.WriteString(w, "\r\n")
 
        // Write body and trailer
@@ -484,6 +509,8 @@ func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) {
                return nil, err
        }
 
+       req.Cookie = readCookies(req.Header)
+
        return req, nil
 }
 
index ae1c4e98245212f9d48d1d2c28f843d99405966c..19083adf624ee72bb209922ca9e63ed1d833969f 100644 (file)
@@ -2,10 +2,15 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-package http
+package http_test
 
 import (
        "bytes"
+       "fmt"
+       . "http"
+       "http/httptest"
+       "io"
+       "os"
        "reflect"
        "regexp"
        "strings"
@@ -141,17 +146,33 @@ func TestMultipartReader(t *testing.T) {
 }
 
 func TestRedirect(t *testing.T) {
-       const (
-               start = "http://google.com/"
-               endRe = "^http://www\\.google\\.[a-z.]+/$"
-       )
-       var end = regexp.MustCompile(endRe)
-       r, url, err := Get(start)
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               switch r.URL.Path {
+               case "/":
+                       w.Header().Set("Location", "/foo/")
+                       w.WriteHeader(StatusSeeOther)
+               case "/foo/":
+                       fmt.Fprintf(w, "foo")
+               default:
+                       w.WriteHeader(StatusBadRequest)
+               }
+       }))
+       defer ts.Close()
+
+       var end = regexp.MustCompile("/foo/$")
+       r, url, err := Get(ts.URL)
        if err != nil {
                t.Fatal(err)
        }
        r.Body.Close()
        if r.StatusCode != 200 || !end.MatchString(url) {
-               t.Fatalf("Get(%s) got status %d at %q, want 200 matching %q", start, r.StatusCode, url, endRe)
+               t.Fatalf("Get got status %d at %q, want 200 matching /foo/$", r.StatusCode, url)
        }
 }
+
+// TODO: stop copy/pasting this around.  move to io/ioutil?
+type nopCloser struct {
+       io.Reader
+}
+
+func (nopCloser) Close() os.Error { return nil }
index 55ca745d58c34fddc84478c311ccd0bdd8619cff..726baa2668633f28c4daddd7ed099cd144094fab 100644 (file)
@@ -6,12 +6,15 @@ package http
 
 import (
        "bytes"
+       "io/ioutil"
        "testing"
 )
 
 type reqWriteTest struct {
-       Req Request
-       Raw string
+       Req      Request
+       Body     []byte
+       Raw      string
+       RawProxy string
 }
 
 var reqWriteTests = []reqWriteTest{
@@ -50,6 +53,8 @@ var reqWriteTests = []reqWriteTest{
                        Form:      map[string][]string{},
                },
 
+               nil,
+
                "GET http://www.techcrunch.com/ HTTP/1.1\r\n" +
                        "Host: www.techcrunch.com\r\n" +
                        "User-Agent: Fake\r\n" +
@@ -59,6 +64,15 @@ var reqWriteTests = []reqWriteTest{
                        "Accept-Language: en-us,en;q=0.5\r\n" +
                        "Keep-Alive: 300\r\n" +
                        "Proxy-Connection: keep-alive\r\n\r\n",
+
+               "GET http://www.techcrunch.com/ HTTP/1.1\r\n" +
+                       "User-Agent: Fake\r\n" +
+                       "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" +
+                       "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" +
+                       "Accept-Encoding: gzip,deflate\r\n" +
+                       "Accept-Language: en-us,en;q=0.5\r\n" +
+                       "Keep-Alive: 300\r\n" +
+                       "Proxy-Connection: keep-alive\r\n\r\n",
        },
        // HTTP/1.1 => chunked coding; body; empty trailer
        {
@@ -71,16 +85,22 @@ var reqWriteTests = []reqWriteTest{
                        },
                        ProtoMajor:       1,
                        ProtoMinor:       1,
-                       Header:           map[string][]string{},
-                       Body:             nopCloser{bytes.NewBufferString("abcdef")},
+                       Header:           Header{},
                        TransferEncoding: []string{"chunked"},
                },
 
+               []byte("abcdef"),
+
                "GET /search HTTP/1.1\r\n" +
                        "Host: www.google.com\r\n" +
                        "User-Agent: Go http package\r\n" +
                        "Transfer-Encoding: chunked\r\n\r\n" +
                        "6\r\nabcdef\r\n0\r\n\r\n",
+
+               "GET http://www.google.com/search HTTP/1.1\r\n" +
+                       "User-Agent: Go http package\r\n" +
+                       "Transfer-Encoding: chunked\r\n\r\n" +
+                       "6\r\nabcdef\r\n0\r\n\r\n",
        },
        // HTTP/1.1 POST => chunked coding; body; empty trailer
        {
@@ -93,18 +113,25 @@ var reqWriteTests = []reqWriteTest{
                        },
                        ProtoMajor:       1,
                        ProtoMinor:       1,
-                       Header:           map[string][]string{},
+                       Header:           Header{},
                        Close:            true,
-                       Body:             nopCloser{bytes.NewBufferString("abcdef")},
                        TransferEncoding: []string{"chunked"},
                },
 
+               []byte("abcdef"),
+
                "POST /search HTTP/1.1\r\n" +
                        "Host: www.google.com\r\n" +
                        "User-Agent: Go http package\r\n" +
                        "Connection: close\r\n" +
                        "Transfer-Encoding: chunked\r\n\r\n" +
                        "6\r\nabcdef\r\n0\r\n\r\n",
+
+               "POST http://www.google.com/search HTTP/1.1\r\n" +
+                       "User-Agent: Go http package\r\n" +
+                       "Connection: close\r\n" +
+                       "Transfer-Encoding: chunked\r\n\r\n" +
+                       "6\r\nabcdef\r\n0\r\n\r\n",
        },
        // default to HTTP/1.1
        {
@@ -114,16 +141,26 @@ var reqWriteTests = []reqWriteTest{
                        Host:   "www.google.com",
                },
 
+               nil,
+
                "GET /search HTTP/1.1\r\n" +
                        "Host: www.google.com\r\n" +
                        "User-Agent: Go http package\r\n" +
                        "\r\n",
+
+               // Looks weird but RawURL overrides what WriteProxy would choose.
+               "GET /search HTTP/1.1\r\n" +
+                       "User-Agent: Go http package\r\n" +
+                       "\r\n",
        },
 }
 
 func TestRequestWrite(t *testing.T) {
        for i := range reqWriteTests {
                tt := &reqWriteTests[i]
+               if tt.Body != nil {
+                       tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(tt.Body))
+               }
                var braw bytes.Buffer
                err := tt.Req.Write(&braw)
                if err != nil {
@@ -135,5 +172,20 @@ func TestRequestWrite(t *testing.T) {
                        t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, tt.Raw, sraw)
                        continue
                }
+
+               if tt.Body != nil {
+                       tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(tt.Body))
+               }
+               var praw bytes.Buffer
+               err = tt.Req.WriteProxy(&praw)
+               if err != nil {
+                       t.Errorf("error writing #%d: %s", i, err)
+                       continue
+               }
+               sraw = praw.String()
+               if sraw != tt.RawProxy {
+                       t.Errorf("Test Proxy %d, expecting:\n%s\nGot:\n%s\n", i, tt.RawProxy, sraw)
+                       continue
+               }
        }
 }
index 3f919c86a3cb775baa6f051439a5d04376594c64..1f725ecdddd0b786ebe65413f386e2137c77d377 100644 (file)
@@ -46,6 +46,9 @@ type Response struct {
        // Keys in the map are canonicalized (see CanonicalHeaderKey).
        Header Header
 
+       // SetCookie records the Set-Cookie requests sent with the response.
+       SetCookie []*Cookie
+
        // Body represents the response body.
        Body io.ReadCloser
 
@@ -64,10 +67,9 @@ type Response struct {
        // ReadResponse nor Response.Write ever closes a connection.
        Close bool
 
-       // Trailer maps trailer keys to values.  Like for Header, if the
-       // response has multiple trailer lines with the same key, they will be
-       // concatenated, delimited by commas.
-       Trailer map[string][]string
+       // Trailer maps trailer keys to values, in the same
+       // format as the header.
+       Trailer Header
 }
 
 // ReadResponse reads and returns an HTTP response from r.  The RequestMethod
@@ -124,6 +126,8 @@ func ReadResponse(r *bufio.Reader, requestMethod string) (resp *Response, err os
                return nil, err
        }
 
+       resp.SetCookie = readSetCookies(resp.Header)
+
        return resp, nil
 }
 
@@ -188,11 +192,15 @@ func (resp *Response) Write(w io.Writer) os.Error {
        }
 
        // Rest of header
-       err = writeSortedKeyValue(w, resp.Header, respExcludeHeader)
+       err = writeSortedHeader(w, resp.Header, respExcludeHeader)
        if err != nil {
                return err
        }
 
+       if err = writeSetCookies(w, resp.SetCookie); err != nil {
+               return err
+       }
+
        // End-of-header
        io.WriteString(w, "\r\n")
 
@@ -206,16 +214,22 @@ func (resp *Response) Write(w io.Writer) os.Error {
        return nil
 }
 
-func writeSortedKeyValue(w io.Writer, kvm map[string][]string, exclude map[string]bool) os.Error {
-       keys := make([]string, 0, len(kvm))
-       for k := range kvm {
-               if !exclude[k] {
+func writeSortedHeader(w io.Writer, h Header, exclude map[string]bool) os.Error {
+       keys := make([]string, 0, len(h))
+       for k := range h {
+               if exclude == nil || !exclude[k] {
                        keys = append(keys, k)
                }
        }
        sort.SortStrings(keys)
        for _, k := range keys {
-               for _, v := range kvm[k] {
+               for _, v := range h[k] {
+                       v = strings.Replace(v, "\n", " ", -1)
+                       v = strings.Replace(v, "\r", " ", -1)
+                       v = strings.TrimSpace(v)
+                       if v == "" {
+                               continue
+                       }
                        if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil {
                                return err
                        }
index aabb833f9c867123dca2cab03195a04b1b16cd06..de0635da516c8ec5a7d2c3572581c5ccf587f17f 100644 (file)
@@ -6,6 +6,7 @@ package http
 
 import (
        "bytes"
+       "io/ioutil"
        "testing"
 )
 
@@ -22,8 +23,8 @@ var respWriteTests = []respWriteTest{
                        ProtoMajor:    1,
                        ProtoMinor:    0,
                        RequestMethod: "GET",
-                       Header:        map[string][]string{},
-                       Body:          nopCloser{bytes.NewBufferString("abcdef")},
+                       Header:        Header{},
+                       Body:          ioutil.NopCloser(bytes.NewBufferString("abcdef")),
                        ContentLength: 6,
                },
 
@@ -38,8 +39,8 @@ var respWriteTests = []respWriteTest{
                        ProtoMajor:    1,
                        ProtoMinor:    0,
                        RequestMethod: "GET",
-                       Header:        map[string][]string{},
-                       Body:          nopCloser{bytes.NewBufferString("abcdef")},
+                       Header:        Header{},
+                       Body:          ioutil.NopCloser(bytes.NewBufferString("abcdef")),
                        ContentLength: -1,
                },
                "HTTP/1.0 200 OK\r\n" +
@@ -53,8 +54,8 @@ var respWriteTests = []respWriteTest{
                        ProtoMajor:       1,
                        ProtoMinor:       1,
                        RequestMethod:    "GET",
-                       Header:           map[string][]string{},
-                       Body:             nopCloser{bytes.NewBufferString("abcdef")},
+                       Header:           Header{},
+                       Body:             ioutil.NopCloser(bytes.NewBufferString("abcdef")),
                        ContentLength:    6,
                        TransferEncoding: []string{"chunked"},
                        Close:            true,
@@ -65,6 +66,29 @@ var respWriteTests = []respWriteTest{
                        "Transfer-Encoding: chunked\r\n\r\n" +
                        "6\r\nabcdef\r\n0\r\n\r\n",
        },
+
+       // Header value with a newline character (Issue 914).
+       // Also tests removal of leading and trailing whitespace.
+       {
+               Response{
+                       StatusCode:    204,
+                       ProtoMajor:    1,
+                       ProtoMinor:    1,
+                       RequestMethod: "GET",
+                       Header: Header{
+                               "Foo": []string{" Bar\nBaz "},
+                       },
+                       Body:             nil,
+                       ContentLength:    0,
+                       TransferEncoding: []string{"chunked"},
+                       Close:            true,
+               },
+
+               "HTTP/1.1 204 No Content\r\n" +
+                       "Connection: close\r\n" +
+                       "Foo: Bar Baz\r\n" +
+                       "\r\n",
+       },
 }
 
 func TestResponseWrite(t *testing.T) {
@@ -78,7 +102,7 @@ func TestResponseWrite(t *testing.T) {
                }
                sraw := braw.String()
                if sraw != tt.Raw {
-                       t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, tt.Raw, sraw)
+                       t.Errorf("Test %d, expecting:\n%q\nGot:\n%q\n", i, tt.Raw, sraw)
                        continue
                }
        }
index 42fe3e5e4d2e9f0420082efc315df347d957ac61..683de85b8670b2134afa52f8551fae1c1dd248d3 100644 (file)
@@ -4,16 +4,18 @@
 
 // End-to-end serving tests
 
-package http
+package http_test
 
 import (
        "bufio"
        "bytes"
        "fmt"
-       "io"
+       . "http"
+       "http/httptest"
        "io/ioutil"
        "os"
        "net"
+       "reflect"
        "strings"
        "testing"
        "time"
@@ -143,7 +145,7 @@ func TestConsumingBodyOnNextConn(t *testing.T) {
 type stringHandler string
 
 func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
-       w.SetHeader("Result", string(s))
+       w.Header().Set("Result", string(s))
 }
 
 var handlers = []struct {
@@ -170,13 +172,10 @@ func TestHostHandlers(t *testing.T) {
        for _, h := range handlers {
                Handle(h.pattern, stringHandler(h.msg))
        }
-       l, err := net.Listen("tcp", "127.0.0.1:0") // any port
-       if err != nil {
-               t.Fatal(err)
-       }
-       defer l.Close()
-       go Serve(l, nil)
-       conn, err := net.Dial("tcp", "", l.Addr().String())
+       ts := httptest.NewServer(nil)
+       defer ts.Close()
+
+       conn, err := net.Dial("tcp", "", ts.Listener.Addr().String())
        if err != nil {
                t.Fatal(err)
        }
@@ -205,46 +204,6 @@ func TestHostHandlers(t *testing.T) {
        }
 }
 
-type responseWriterMethodCall struct {
-       method                 string
-       headerKey, headerValue string // if method == "SetHeader"
-       bytesWritten           []byte // if method == "Write"
-       responseCode           int    // if method == "WriteHeader"
-}
-
-type recordingResponseWriter struct {
-       log []*responseWriterMethodCall
-}
-
-func (rw *recordingResponseWriter) RemoteAddr() string {
-       return "1.2.3.4"
-}
-
-func (rw *recordingResponseWriter) UsingTLS() bool {
-       return false
-}
-
-func (rw *recordingResponseWriter) SetHeader(k, v string) {
-       rw.log = append(rw.log, &responseWriterMethodCall{method: "SetHeader", headerKey: k, headerValue: v})
-}
-
-func (rw *recordingResponseWriter) Write(buf []byte) (int, os.Error) {
-       rw.log = append(rw.log, &responseWriterMethodCall{method: "Write", bytesWritten: buf})
-       return len(buf), nil
-}
-
-func (rw *recordingResponseWriter) WriteHeader(code int) {
-       rw.log = append(rw.log, &responseWriterMethodCall{method: "WriteHeader", responseCode: code})
-}
-
-func (rw *recordingResponseWriter) Flush() {
-       rw.log = append(rw.log, &responseWriterMethodCall{method: "Flush"})
-}
-
-func (rw *recordingResponseWriter) Hijack() (io.ReadWriteCloser, *bufio.ReadWriter, os.Error) {
-       panic("Not supported")
-}
-
 // Tests for http://code.google.com/p/go/issues/detail?id=900
 func TestMuxRedirectLeadingSlashes(t *testing.T) {
        paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
@@ -254,41 +213,24 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) {
                        t.Errorf("%s", err)
                }
                mux := NewServeMux()
-               resp := new(recordingResponseWriter)
-               resp.log = make([]*responseWriterMethodCall, 0)
+               resp := httptest.NewRecorder()
 
                mux.ServeHTTP(resp, req)
 
-               dumpLog := func() {
-                       t.Logf("For path %q:", path)
-                       for _, call := range resp.log {
-                               t.Logf("Got call: %s, header=%s, value=%s, buf=%q, code=%d", call.method,
-                                       call.headerKey, call.headerValue, call.bytesWritten, call.responseCode)
-                       }
-               }
-
-               if len(resp.log) != 2 {
-                       dumpLog()
-                       t.Errorf("expected 2 calls to response writer; got %d", len(resp.log))
-                       return
-               }
-
-               if resp.log[0].method != "SetHeader" ||
-                       resp.log[0].headerKey != "Location" || resp.log[0].headerValue != "/foo.txt" {
-                       dumpLog()
-                       t.Errorf("Expected SetHeader of Location to /foo.txt")
+               if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
+                       t.Errorf("Expected Location header set to %q; got %q", expected, loc)
                        return
                }
 
-               if resp.log[1].method != "WriteHeader" || resp.log[1].responseCode != StatusMovedPermanently {
-                       dumpLog()
-                       t.Errorf("Expected WriteHeader of StatusMovedPermanently")
+               if code, expected := resp.Code, StatusMovedPermanently; code != expected {
+                       t.Errorf("Expected response code of StatusMovedPermanently; got %d", code)
                        return
                }
        }
 }
 
 func TestServerTimeouts(t *testing.T) {
+       // TODO(bradfitz): convert this to use httptest.Server
        l, err := net.ListenTCP("tcp", &net.TCPAddr{Port: 0})
        if err != nil {
                t.Fatalf("listen error: %v", err)
@@ -308,7 +250,9 @@ func TestServerTimeouts(t *testing.T) {
        url := fmt.Sprintf("http://localhost:%d/", addr.Port)
 
        // Hit the HTTP server successfully.
-       r, _, err := Get(url)
+       tr := &Transport{DisableKeepAlives: true} // they interfere with this test
+       c := &Client{Transport: tr}
+       r, _, err := c.Get(url)
        if err != nil {
                t.Fatalf("http Get #1: %v", err)
        }
@@ -353,16 +297,9 @@ func TestServerTimeouts(t *testing.T) {
 
 // TestIdentityResponse verifies that a handler can unset 
 func TestIdentityResponse(t *testing.T) {
-       l, err := net.Listen("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatalf("failed to listen on a port: %v", err)
-       }
-       defer l.Close()
-       urlBase := "http://" + l.Addr().String() + "/"
-
        handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
-               rw.SetHeader("Content-Length", "3")
-               rw.SetHeader("Transfer-Encoding", req.FormValue("te"))
+               rw.Header().Set("Content-Length", "3")
+               rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
                switch {
                case req.FormValue("overwrite") == "1":
                        _, err := rw.Write([]byte("foo TOO LONG"))
@@ -370,22 +307,22 @@ func TestIdentityResponse(t *testing.T) {
                                t.Errorf("expected ErrContentLength; got %v", err)
                        }
                case req.FormValue("underwrite") == "1":
-                       rw.SetHeader("Content-Length", "500")
+                       rw.Header().Set("Content-Length", "500")
                        rw.Write([]byte("too short"))
                default:
                        rw.Write([]byte("foo"))
                }
        })
 
-       server := &Server{Handler: handler}
-       go server.Serve(l)
+       ts := httptest.NewServer(handler)
+       defer ts.Close()
 
        // Note: this relies on the assumption (which is true) that
        // Get sends HTTP/1.1 or greater requests.  Otherwise the
        // server wouldn't have the choice to send back chunked
        // responses.
        for _, te := range []string{"", "identity"} {
-               url := urlBase + "?te=" + te
+               url := ts.URL + "/?te=" + te
                res, _, err := Get(url)
                if err != nil {
                        t.Fatalf("error with Get of %s: %v", url, err)
@@ -400,18 +337,18 @@ func TestIdentityResponse(t *testing.T) {
                        t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
                                url, expected, tl, res.TransferEncoding)
                }
+               res.Body.Close()
        }
 
        // Verify that ErrContentLength is returned
-       url := urlBase + "?overwrite=1"
-       _, _, err = Get(url)
+       url := ts.URL + "/?overwrite=1"
+       _, _, err := Get(url)
        if err != nil {
                t.Fatalf("error with Get of %s: %v", url, err)
        }
-
        // Verify that the connection is closed when the declared Content-Length
        // is larger than what the handler wrote.
-       conn, err := net.Dial("tcp", "", l.Addr().String())
+       conn, err := net.Dial("tcp", "", ts.Listener.Addr().String())
        if err != nil {
                t.Fatalf("error dialing: %v", err)
        }
@@ -432,3 +369,141 @@ func TestIdentityResponse(t *testing.T) {
                        expectedSuffix, string(got))
        }
 }
+
+// TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive.
+func TestServeHTTP10Close(t *testing.T) {
+       s := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               ServeFile(w, r, "testdata/file")
+       }))
+       defer s.Close()
+
+       conn, err := net.Dial("tcp", "", s.Listener.Addr().String())
+       if err != nil {
+               t.Fatal("dial error:", err)
+       }
+       defer conn.Close()
+
+       _, err = fmt.Fprint(conn, "GET / HTTP/1.0\r\n\r\n")
+       if err != nil {
+               t.Fatal("print error:", err)
+       }
+
+       r := bufio.NewReader(conn)
+       _, err = ReadResponse(r, "GET")
+       if err != nil {
+               t.Fatal("ReadResponse error:", err)
+       }
+
+       success := make(chan bool)
+       go func() {
+               select {
+               case <-time.After(5e9):
+                       t.Fatal("body not closed after 5s")
+               case <-success:
+               }
+       }()
+
+       _, err = ioutil.ReadAll(r)
+       if err != nil {
+               t.Fatal("read error:", err)
+       }
+
+       success <- true
+}
+
+func TestSetsRemoteAddr(t *testing.T) {
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               fmt.Fprintf(w, "%s", r.RemoteAddr)
+       }))
+       defer ts.Close()
+
+       res, _, err := Get(ts.URL)
+       if err != nil {
+               t.Fatalf("Get error: %v", err)
+       }
+       body, err := ioutil.ReadAll(res.Body)
+       if err != nil {
+               t.Fatalf("ReadAll error: %v", err)
+       }
+       ip := string(body)
+       if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
+               t.Fatalf("Expected local addr; got %q", ip)
+       }
+}
+
+func TestChunkedResponseHeaders(t *testing.T) {
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
+               fmt.Fprintf(w, "I am a chunked response.")
+       }))
+       defer ts.Close()
+
+       res, _, err := Get(ts.URL)
+       if err != nil {
+               t.Fatalf("Get error: %v", err)
+       }
+       if g, e := res.ContentLength, int64(-1); g != e {
+               t.Errorf("expected ContentLength of %d; got %d", e, g)
+       }
+       if g, e := res.TransferEncoding, []string{"chunked"}; !reflect.DeepEqual(g, e) {
+               t.Errorf("expected TransferEncoding of %v; got %v", e, g)
+       }
+       if _, haveCL := res.Header["Content-Length"]; haveCL {
+               t.Errorf("Unexpected Content-Length")
+       }
+}
+
+// Test304Responses verifies that 304s don't declare that they're
+// chunking in their response headers and aren't allowed to produce
+// output.
+func Test304Responses(t *testing.T) {
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               w.WriteHeader(StatusNotModified)
+               _, err := w.Write([]byte("illegal body"))
+               if err != ErrBodyNotAllowed {
+                       t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
+               }
+       }))
+       defer ts.Close()
+       res, _, err := Get(ts.URL)
+       if err != nil {
+               t.Error(err)
+       }
+       if len(res.TransferEncoding) > 0 {
+               t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
+       }
+       body, err := ioutil.ReadAll(res.Body)
+       if err != nil {
+               t.Error(err)
+       }
+       if len(body) > 0 {
+               t.Errorf("got unexpected body %q", string(body))
+       }
+}
+
+// TestHeadResponses verifies that responses to HEAD requests don't
+// declare that they're chunking in their response headers and aren't
+// allowed to produce output.
+func TestHeadResponses(t *testing.T) {
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               _, err := w.Write([]byte("Ignored body"))
+               if err != ErrBodyNotAllowed {
+                       t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
+               }
+       }))
+       defer ts.Close()
+       res, err := Head(ts.URL)
+       if err != nil {
+               t.Error(err)
+       }
+       if len(res.TransferEncoding) > 0 {
+               t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
+       }
+       body, err := ioutil.ReadAll(res.Body)
+       if err != nil {
+               t.Error(err)
+       }
+       if len(body) > 0 {
+               t.Errorf("got unexpected body %q", string(body))
+       }
+}
index 977c8c2297a10b09dbdc0b041efb3664860b4f23..8e7039371ae1d011dcfafa912020c0c7eeac3bd0 100644 (file)
@@ -6,7 +6,6 @@
 
 // TODO(rsc):
 //     logging
-//     cgi support
 //     post support
 
 package http
@@ -49,23 +48,10 @@ type Handler interface {
 // A ResponseWriter interface is used by an HTTP handler to
 // construct an HTTP response.
 type ResponseWriter interface {
-       // RemoteAddr returns the address of the client that sent the current request
-       RemoteAddr() string
-
-       // UsingTLS returns true if the client is connected using TLS
-       UsingTLS() bool
-
-       // SetHeader sets a header line in the eventual response.
-       // For example, SetHeader("Content-Type", "text/html; charset=utf-8")
-       // will result in the header line
-       //
-       //      Content-Type: text/html; charset=utf-8
-       //
-       // being sent. UTF-8 encoded HTML is the default setting for
-       // Content-Type in this library, so users need not make that
-       // particular call. Calls to SetHeader after WriteHeader (or Write)
-       // are ignored. An empty value removes the header if previously set.
-       SetHeader(string, string)
+       // Header returns the header map that will be sent by WriteHeader.
+       // Changing the header after a call to WriteHeader (or Write) has
+       // no effect.
+       Header() Header
 
        // Write writes the data to the connection as part of an HTTP reply.
        // If WriteHeader has not yet been called, Write calls WriteHeader(http.StatusOK)
@@ -78,39 +64,52 @@ type ResponseWriter interface {
        // Thus explicit calls to WriteHeader are mainly used to
        // send error codes.
        WriteHeader(int)
+}
 
+// The Flusher interface is implemented by ResponseWriters that allow
+// an HTTP handler to flush buffered data to the client.
+//
+// Note that even for ResponseWriters that support Flush,
+// if the client is connected through an HTTP proxy,
+// the buffered data may not reach the client until the response
+// completes.
+type Flusher interface {
        // Flush sends any buffered data to the client.
        Flush()
+}
 
+// The Hijacker interface is implemented by ResponseWriters that allow
+// an HTTP handler to take over the connection.
+type Hijacker interface {
        // Hijack lets the caller take over the connection.
        // After a call to Hijack(), the HTTP server library
        // will not do anything else with the connection.
        // It becomes the caller's responsibility to manage
        // and close the connection.
-       Hijack() (io.ReadWriteCloser, *bufio.ReadWriter, os.Error)
+       Hijack() (net.Conn, *bufio.ReadWriter, os.Error)
 }
 
 // A conn represents the server side of an HTTP connection.
 type conn struct {
-       remoteAddr string             // network address of remote side
-       handler    Handler            // request handler
-       rwc        io.ReadWriteCloser // i/o connection
-       buf        *bufio.ReadWriter  // buffered rwc
-       hijacked   bool               // connection has been hijacked by handler
-       usingTLS   bool               // a flag indicating connection over TLS
+       remoteAddr string               // network address of remote side
+       handler    Handler              // request handler
+       rwc        net.Conn             // i/o connection
+       buf        *bufio.ReadWriter    // buffered rwc
+       hijacked   bool                 // connection has been hijacked by handler
+       tlsState   *tls.ConnectionState // or nil when not using TLS        
 }
 
 // A response represents the server side of an HTTP response.
 type response struct {
        conn          *conn
-       req           *Request          // request for this response
-       chunking      bool              // using chunked transfer encoding for reply body
-       wroteHeader   bool              // reply header has been written
-       wroteContinue bool              // 100 Continue response was written
-       header        map[string]string // reply header parameters
-       written       int64             // number of bytes written in body
-       contentLength int64             // explicitly-declared Content-Length; or -1
-       status        int               // status code passed to WriteHeader
+       req           *Request // request for this response
+       chunking      bool     // using chunked transfer encoding for reply body
+       wroteHeader   bool     // reply header has been written
+       wroteContinue bool     // 100 Continue response was written
+       header        Header   // reply header parameters
+       written       int64    // number of bytes written in body
+       contentLength int64    // explicitly-declared Content-Length; or -1
+       status        int      // status code passed to WriteHeader
 
        // close connection after this reply.  set on request and
        // updated after response from handler if there's a
@@ -125,10 +124,15 @@ func newConn(rwc net.Conn, handler Handler) (c *conn, err os.Error) {
        c.remoteAddr = rwc.RemoteAddr().String()
        c.handler = handler
        c.rwc = rwc
-       _, c.usingTLS = rwc.(*tls.Conn)
        br := bufio.NewReader(rwc)
        bw := bufio.NewWriter(rwc)
        c.buf = bufio.NewReadWriter(br, bw)
+
+       if tlsConn, ok := rwc.(*tls.Conn); ok {
+               c.tlsState = new(tls.ConnectionState)
+               *c.tlsState = tlsConn.ConnectionState()
+       }
+
        return c, nil
 }
 
@@ -168,10 +172,13 @@ func (c *conn) readRequest() (w *response, err os.Error) {
                return nil, err
        }
 
+       req.RemoteAddr = c.remoteAddr
+       req.TLS = c.tlsState
+
        w = new(response)
        w.conn = c
        w.req = req
-       w.header = make(map[string]string)
+       w.header = make(Header)
        w.contentLength = -1
 
        // Expect 100 Continue support
@@ -182,21 +189,10 @@ func (c *conn) readRequest() (w *response, err os.Error) {
        return w, nil
 }
 
-// UsingTLS implements the ResponseWriter.UsingTLS
-func (w *response) UsingTLS() bool {
-       return w.conn.usingTLS
-}
-
-// RemoteAddr implements the ResponseWriter.RemoteAddr method
-func (w *response) RemoteAddr() string { return w.conn.remoteAddr }
-
-// SetHeader implements the ResponseWriter.SetHeader method
-// An empty value removes the header from the map.
-func (w *response) SetHeader(hdr, val string) {
-       w.header[CanonicalHeaderKey(hdr)] = val, val != ""
+func (w *response) Header() Header {
+       return w.header
 }
 
-// WriteHeader implements the ResponseWriter.WriteHeader method
 func (w *response) WriteHeader(code int) {
        if w.conn.hijacked {
                log.Print("http: response.WriteHeader on hijacked connection")
@@ -211,55 +207,55 @@ func (w *response) WriteHeader(code int) {
        if code == StatusNotModified {
                // Must not have body.
                for _, header := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} {
-                       if w.header[header] != "" {
+                       if w.header.Get(header) != "" {
                                // TODO: return an error if WriteHeader gets a return parameter
                                // or set a flag on w to make future Writes() write an error page?
                                // for now just log and drop the header.
                                log.Printf("http: StatusNotModified response with header %q defined", header)
-                               w.header[header] = "", false
+                               w.header.Del(header)
                        }
                }
        } else {
                // Default output is HTML encoded in UTF-8.
-               if w.header["Content-Type"] == "" {
-                       w.SetHeader("Content-Type", "text/html; charset=utf-8")
+               if w.header.Get("Content-Type") == "" {
+                       w.header.Set("Content-Type", "text/html; charset=utf-8")
                }
        }
 
-       if w.header["Date"] == "" {
-               w.SetHeader("Date", time.UTC().Format(TimeFormat))
+       if w.header.Get("Date") == "" {
+               w.Header().Set("Date", time.UTC().Format(TimeFormat))
        }
 
        // Check for a explicit (and valid) Content-Length header.
        var hasCL bool
        var contentLength int64
-       if clenStr, ok := w.header["Content-Length"]; ok {
+       if clenStr := w.header.Get("Content-Length"); clenStr != "" {
                var err os.Error
                contentLength, err = strconv.Atoi64(clenStr)
                if err == nil {
                        hasCL = true
                } else {
                        log.Printf("http: invalid Content-Length of %q sent", clenStr)
-                       w.SetHeader("Content-Length", "")
+                       w.header.Del("Content-Length")
                }
        }
 
-       te, hasTE := w.header["Transfer-Encoding"]
+       te := w.header.Get("Transfer-Encoding")
+       hasTE := te != ""
        if hasCL && hasTE && te != "identity" {
                // TODO: return an error if WriteHeader gets a return parameter
                // For now just ignore the Content-Length.
                log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d",
                        te, contentLength)
-               w.SetHeader("Content-Length", "")
+               w.header.Del("Content-Length")
                hasCL = false
        }
 
-       if w.req.Method == "HEAD" {
+       if w.req.Method == "HEAD" || code == StatusNotModified {
                // do nothing
        } else if hasCL {
-               w.chunking = false
                w.contentLength = contentLength
-               w.SetHeader("Transfer-Encoding", "")
+               w.header.Del("Transfer-Encoding")
        } else if w.req.ProtoAtLeast(1, 1) {
                // HTTP/1.1 or greater: use chunked transfer encoding
                // to avoid closing the connection at EOF.
@@ -267,26 +263,28 @@ func (w *response) WriteHeader(code int) {
                // might have set.  Deal with that as need arises once we have a valid
                // use case.
                w.chunking = true
-               w.SetHeader("Transfer-Encoding", "chunked")
+               w.header.Set("Transfer-Encoding", "chunked")
        } else {
                // HTTP version < 1.1: cannot do chunked transfer
                // encoding and we don't know the Content-Length so
                // signal EOF by closing connection.
                w.closeAfterReply = true
-               w.chunking = false                   // redundant
-               w.SetHeader("Transfer-Encoding", "") // in case already set
+               w.header.Del("Transfer-Encoding") // in case already set
        }
 
        if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) {
                _, connectionHeaderSet := w.header["Connection"]
                if !connectionHeaderSet {
-                       w.SetHeader("Connection", "keep-alive")
+                       w.header.Set("Connection", "keep-alive")
                }
+       } else if !w.req.ProtoAtLeast(1, 1) {
+               // Client did not ask to keep connection alive.
+               w.closeAfterReply = true
        }
 
        // Cannot use Content-Length with non-identity Transfer-Encoding.
        if w.chunking {
-               w.SetHeader("Content-Length", "")
+               w.header.Del("Content-Length")
        }
        if !w.req.ProtoAtLeast(1, 0) {
                return
@@ -301,13 +299,10 @@ func (w *response) WriteHeader(code int) {
                text = "status code " + codestring
        }
        io.WriteString(w.conn.buf, proto+" "+codestring+" "+text+"\r\n")
-       for k, v := range w.header {
-               io.WriteString(w.conn.buf, k+": "+v+"\r\n")
-       }
+       writeSortedHeader(w.conn.buf, w.header, nil)
        io.WriteString(w.conn.buf, "\r\n")
 }
 
-// Write implements the ResponseWriter.Write method
 func (w *response) Write(data []byte) (n int, err os.Error) {
        if w.conn.hijacked {
                log.Print("http: response.Write on hijacked connection")
@@ -382,7 +377,7 @@ func errorKludge(w *response) {
        msg += " would ignore this error page if this text weren't here.\n"
 
        // Is it text?  ("Content-Type" is always in the map)
-       baseType := strings.Split(w.header["Content-Type"], ";", 2)[0]
+       baseType := strings.Split(w.header.Get("Content-Type"), ";", 2)[0]
        switch baseType {
        case "text/html":
                io.WriteString(w, "<!-- ")
@@ -402,8 +397,8 @@ func (w *response) finishRequest() {
        // If this was an HTTP/1.0 request with keep-alive and we sent a Content-Length
        // back, we can make this a keep-alive response ...
        if w.req.wantsHttp10KeepAlive() {
-               _, sentLength := w.header["Content-Length"]
-               if sentLength && w.header["Connection"] == "keep-alive" {
+               sentLength := w.header.Get("Content-Length") != ""
+               if sentLength && w.header.Get("Connection") == "keep-alive" {
                        w.closeAfterReply = false
                }
        }
@@ -425,7 +420,6 @@ func (w *response) finishRequest() {
        }
 }
 
-// Flush implements the ResponseWriter.Flush method.
 func (w *response) Flush() {
        if !w.wroteHeader {
                w.WriteHeader(StatusOK)
@@ -469,8 +463,9 @@ func (c *conn) serve() {
        c.close()
 }
 
-// Hijack impements the ResponseWriter.Hijack method.
-func (w *response) Hijack() (rwc io.ReadWriteCloser, buf *bufio.ReadWriter, err os.Error) {
+// Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter
+// and a Hijacker.
+func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err os.Error) {
        if w.conn.hijacked {
                return nil, nil, ErrHijacked
        }
@@ -497,7 +492,7 @@ func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) {
 
 // Error replies to the request with the specified error message and HTTP code.
 func Error(w ResponseWriter, error string, code int) {
-       w.SetHeader("Content-Type", "text/plain; charset=utf-8")
+       w.Header().Set("Content-Type", "text/plain; charset=utf-8")
        w.WriteHeader(code)
        fmt.Fprintln(w, error)
 }
@@ -550,7 +545,7 @@ func Redirect(w ResponseWriter, r *Request, url string, code int) {
                }
        }
 
-       w.SetHeader("Location", url)
+       w.Header().Set("Location", url)
        w.WriteHeader(code)
 
        // RFC2616 recommends that a short note "SHOULD" be included in the
@@ -673,7 +668,7 @@ func (mux *ServeMux) match(path string) Handler {
 func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) {
        // Clean path to canonical form and redirect.
        if p := cleanPath(r.URL.Path); p != r.URL.Path {
-               w.SetHeader("Location", p)
+               w.Header().Set("Location", p)
                w.WriteHeader(StatusMovedPermanently)
                return
        }
@@ -826,7 +821,7 @@ func ListenAndServe(addr string, handler Handler) os.Error {
 //     )
 //
 //     func handler(w http.ResponseWriter, req *http.Request) {
-//             w.SetHeader("Content-Type", "text/plain")
+//             w.Header().Set("Content-Type", "text/plain")
 //             w.Write([]byte("This is an example server.\n"))
 //     }
 //
index 41d639c7e2feca508cc36c2eb3570ed20c4e9888..8a73ead31f9f855c4598232474333a87963fb527 100644 (file)
@@ -9,53 +9,111 @@ import (
        "crypto/tls"
        "encoding/base64"
        "fmt"
+       "io"
+       "log"
        "net"
        "os"
        "strings"
        "sync"
 )
 
-// DefaultTransport is the default implementation of ClientTransport
-// and is used by DefaultClient.  It establishes a new network connection for
-// each call to Do and uses HTTP proxies as directed by the $HTTP_PROXY and
-// $NO_PROXY (or $http_proxy and $no_proxy) environment variables.
-var DefaultTransport ClientTransport = &transport{}
-
-// transport implements http.ClientTranport for the default case,
-// using TCP connections to either the host or a proxy, serving
-// http or https schemes.  In the future this may become public
-// and support options on keep-alive connection duration, pipelining
-// controls, etc.  For now this is simply a port of the old Go code
-// client code to the http.ClientTransport interface.
-type transport struct {
-       // TODO: keep-alives, pipelining, etc using a map from
-       // scheme/host to a connection.  Something like:
-       l        sync.Mutex
-       hostConn map[string]*ClientConn
-}
-
-func (ct *transport) Do(req *Request) (resp *Response, err os.Error) {
+// DefaultTransport is the default implementation of Transport and is
+// used by DefaultClient.  It establishes a new network connection for
+// each call to Do and uses HTTP proxies as directed by the
+// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy)
+// environment variables.
+var DefaultTransport RoundTripper = &Transport{}
+
+// Transport is an implementation of RoundTripper that supports http,
+// https, and http proxies (for either http or https with CONNECT).
+// Transport can also cache connections for future re-use.
+type Transport struct {
+       lk       sync.Mutex
+       idleConn map[string][]*persistConn
+
+       // TODO: tunables on max cached connections (total, per-server), duration
+       // TODO: optional pipelining
+
+       IgnoreEnvironment bool // don't look at environment variables for proxy configuration
+       DisableKeepAlives bool
+}
+
+// RoundTrip implements the RoundTripper interface.
+func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
+       if req.URL == nil {
+               if req.URL, err = ParseURL(req.RawURL); err != nil {
+                       return
+               }
+       }
        if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
                return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme}
        }
 
-       addr := req.URL.Host
-       if !hasPort(addr) {
-               addr += ":" + req.URL.Scheme
+       cm, err := t.connectMethodForRequest(req)
+       if err != nil {
+               return nil, err
+       }
+
+       // Get the cached or newly-created connection to either the
+       // host (for http or https), the http proxy, or the http proxy
+       // pre-CONNECTed to https server.  In any case, we'll be ready
+       // to send it requests.
+       pconn, err := t.getConn(cm)
+       if err != nil {
+               return nil, err
        }
 
-       var proxyURL *URL
-       proxyAuth := ""
-       proxy := ""
-       if !matchNoProxy(addr) {
-               proxy = os.Getenv("HTTP_PROXY")
-               if proxy == "" {
-                       proxy = os.Getenv("http_proxy")
+       return pconn.roundTrip(req)
+}
+
+// CloseIdleConnections closes any connections which were previously
+// connected from previous requests but are now sitting idle in
+// a "keep-alive" state. It does not interrupt any connections currently
+// in use.
+func (t *Transport) CloseIdleConnections() {
+       t.lk.Lock()
+       defer t.lk.Unlock()
+       if t.idleConn == nil {
+               return
+       }
+       for _, conns := range t.idleConn {
+               for _, pconn := range conns {
+                       pconn.close()
                }
        }
+       t.idleConn = nil
+}
 
-       if proxy != "" {
-               proxyURL, err = ParseRequestURL(proxy)
+//
+// Private implementation past this point.
+//
+
+func (t *Transport) getenvEitherCase(k string) string {
+       if t.IgnoreEnvironment {
+               return ""
+       }
+       if v := t.getenv(strings.ToUpper(k)); v != "" {
+               return v
+       }
+       return t.getenv(strings.ToLower(k))
+}
+
+func (t *Transport) getenv(k string) string {
+       if t.IgnoreEnvironment {
+               return ""
+       }
+       return os.Getenv(k)
+}
+
+func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Error) {
+       cm := &connectMethod{
+               targetScheme: req.URL.Scheme,
+               targetAddr:   canonicalAddr(req.URL),
+       }
+
+       proxy := t.getenvEitherCase("HTTP_PROXY")
+       if proxy != "" && t.useProxy(cm.targetAddr) {
+               proxyURL, err := ParseRequestURL(proxy)
                if err != nil {
                        return nil, os.ErrorString("invalid proxy address")
                }
@@ -65,83 +123,405 @@ func (ct *transport) Do(req *Request) (resp *Response, err os.Error) {
                                return nil, os.ErrorString("invalid proxy address")
                        }
                }
-               addr = proxyURL.Host
-               proxyInfo := proxyURL.RawUserinfo
-               if proxyInfo != "" {
-                       enc := base64.URLEncoding
-                       encoded := make([]byte, enc.EncodedLen(len(proxyInfo)))
-                       enc.Encode(encoded, []byte(proxyInfo))
-                       proxyAuth = "Basic " + string(encoded)
+               cm.proxyURL = proxyURL
+       }
+       return cm, nil
+}
+
+// proxyAuth returns the Proxy-Authorization header to set
+// on requests, if applicable.
+func (cm *connectMethod) proxyAuth() string {
+       if cm.proxyURL == nil {
+               return ""
+       }
+       proxyInfo := cm.proxyURL.RawUserinfo
+       if proxyInfo != "" {
+               enc := base64.URLEncoding
+               encoded := make([]byte, enc.EncodedLen(len(proxyInfo)))
+               enc.Encode(encoded, []byte(proxyInfo))
+               return "Basic " + string(encoded)
+       }
+       return ""
+}
+
+func (t *Transport) putIdleConn(pconn *persistConn) {
+       t.lk.Lock()
+       defer t.lk.Unlock()
+       if t.DisableKeepAlives {
+               pconn.close()
+               return
+       }
+       if pconn.isBroken() {
+               return
+       }
+       key := pconn.cacheKey
+       t.idleConn[key] = append(t.idleConn[key], pconn)
+}
+
+func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
+       t.lk.Lock()
+       defer t.lk.Unlock()
+       if t.idleConn == nil {
+               t.idleConn = make(map[string][]*persistConn)
+       }
+       key := cm.String()
+       for {
+               pconns, ok := t.idleConn[key]
+               if !ok {
+                       return nil
+               }
+               if len(pconns) == 1 {
+                       pconn = pconns[0]
+                       t.idleConn[key] = nil, false
+               } else {
+                       // 2 or more cached connections; pop last
+                       // TODO: queue?
+                       pconn = pconns[len(pconns)-1]
+                       t.idleConn[key] = pconns[0 : len(pconns)-1]
+               }
+               if !pconn.isBroken() {
+                       return
                }
        }
+       return
+}
+
+// getConn dials and creates a new persistConn to the target as
+// specified in the connectMethod.  This includes doing a proxy CONNECT
+// and/or setting up TLS.  If this doesn't return an error, the persistConn
+// is ready to write requests to.
+func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
+       if pc := t.getIdleConn(cm); pc != nil {
+               return pc, nil
+       }
 
-       // Connect to server or proxy
-       conn, err := net.Dial("tcp", "", addr)
+       conn, err := net.Dial("tcp", "", cm.addr())
        if err != nil {
                return nil, err
        }
 
-       if req.URL.Scheme == "http" {
-               // Include proxy http header if needed.
-               if proxyAuth != "" {
-                       req.Header.Set("Proxy-Authorization", proxyAuth)
-               }
-       } else { // https
-               if proxyURL != nil {
-                       // Ask proxy for direct connection to server.
-                       // addr defaults above to ":https" but we need to use numbers
-                       addr = req.URL.Host
-                       if !hasPort(addr) {
-                               addr += ":443"
-                       }
-                       fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\n", addr)
-                       fmt.Fprintf(conn, "Host: %s\r\n", addr)
-                       if proxyAuth != "" {
-                               fmt.Fprintf(conn, "Proxy-Authorization: %s\r\n", proxyAuth)
-                       }
-                       fmt.Fprintf(conn, "\r\n")
+       pa := cm.proxyAuth()
 
-                       // Read response.
-                       // Okay to use and discard buffered reader here, because
-                       // TLS server will not speak until spoken to.
-                       br := bufio.NewReader(conn)
-                       resp, err := ReadResponse(br, "CONNECT")
-                       if err != nil {
-                               return nil, err
-                       }
-                       if resp.StatusCode != 200 {
-                               f := strings.Split(resp.Status, " ", 2)
-                               return nil, os.ErrorString(f[1])
+       pconn := &persistConn{
+               t:        t,
+               cacheKey: cm.String(),
+               conn:     conn,
+               reqch:    make(chan requestAndChan, 50),
+       }
+       newClientConnFunc := NewClientConn
+
+       switch {
+       case cm.proxyURL == nil:
+               // Do nothing.
+       case cm.targetScheme == "http":
+               newClientConnFunc = NewProxyClientConn
+               if pa != "" {
+                       pconn.mutateRequestFunc = func(req *Request) {
+                               if req.Header == nil {
+                                       req.Header = make(Header)
+                               }
+                               req.Header.Set("Proxy-Authorization", pa)
                        }
                }
+       case cm.targetScheme == "https":
+               fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\n", cm.targetAddr)
+               fmt.Fprintf(conn, "Host: %s\r\n", cm.targetAddr)
+               if pa != "" {
+                       fmt.Fprintf(conn, "Proxy-Authorization: %s\r\n", pa)
+               }
+               fmt.Fprintf(conn, "\r\n")
 
+               // Read response.
+               // Okay to use and discard buffered reader here, because
+               // TLS server will not speak until spoken to.
+               br := bufio.NewReader(conn)
+               resp, err := ReadResponse(br, "CONNECT")
+               if err != nil {
+                       conn.Close()
+                       return nil, err
+               }
+               if resp.StatusCode != 200 {
+                       f := strings.Split(resp.Status, " ", 2)
+                       conn.Close()
+                       return nil, os.ErrorString(f[1])
+               }
+       }
+
+       if cm.targetScheme == "https" {
                // Initiate TLS and check remote host name against certificate.
                conn = tls.Client(conn, nil)
                if err = conn.(*tls.Conn).Handshake(); err != nil {
                        return nil, err
                }
-               h := req.URL.Host
-               if hasPort(h) {
-                       h = h[:strings.LastIndex(h, ":")]
-               }
-               if err = conn.(*tls.Conn).VerifyHostname(h); err != nil {
+               if err = conn.(*tls.Conn).VerifyHostname(cm.tlsHost()); err != nil {
                        return nil, err
                }
+               pconn.conn = conn
        }
 
-       err = req.Write(conn)
-       if err != nil {
-               conn.Close()
-               return nil, err
+       pconn.br = bufio.NewReader(pconn.conn)
+       pconn.cc = newClientConnFunc(conn, pconn.br)
+       pconn.cc.readRes = readResponseWithEOFSignal
+       go pconn.readLoop()
+       return pconn, nil
+}
+
+// useProxy returns true if requests to addr should use a proxy,
+// according to the NO_PROXY or no_proxy environment variable.
+func (t *Transport) useProxy(addr string) bool {
+       if len(addr) == 0 {
+               return true
+       }
+       no_proxy := t.getenvEitherCase("NO_PROXY")
+       if no_proxy == "*" {
+               return false
+       }
+
+       addr = strings.ToLower(strings.TrimSpace(addr))
+       if hasPort(addr) {
+               addr = addr[:strings.LastIndex(addr, ":")]
+       }
+
+       for _, p := range strings.Split(no_proxy, ",", -1) {
+               p = strings.ToLower(strings.TrimSpace(p))
+               if len(p) == 0 {
+                       continue
+               }
+               if hasPort(p) {
+                       p = p[:strings.LastIndex(p, ":")]
+               }
+               if addr == p || (p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:])) {
+                       return false
+               }
+       }
+       return true
+}
+
+// connectMethod is the map key (in its String form) for keeping persistent
+// TCP connections alive for subsequent HTTP requests.
+//
+// A connect method may be of the following types:
+//
+// Cache key form                Description
+// -----------------             -------------------------
+// ||http|foo.com                http directly to server, no proxy
+// ||https|foo.com               https directly to server, no proxy
+// http://proxy.com|https|foo.com  http to proxy, then CONNECT to foo.com
+// http://proxy.com|http           http to proxy, http to anywhere after that
+//
+// Note: no support to https to the proxy yet.
+//
+type connectMethod struct {
+       proxyURL     *URL   // "" for no proxy, else full proxy URL
+       targetScheme string // "http" or "https"
+       targetAddr   string // Not used if proxy + http targetScheme (4th example in table)
+}
+
+func (ck *connectMethod) String() string {
+       proxyStr := ""
+       if ck.proxyURL != nil {
+               proxyStr = ck.proxyURL.String()
+       }
+       return strings.Join([]string{proxyStr, ck.targetScheme, ck.targetAddr}, "|")
+}
+
+// addr returns the first hop "host:port" to which we need to TCP connect.
+func (cm *connectMethod) addr() string {
+       if cm.proxyURL != nil {
+               return canonicalAddr(cm.proxyURL)
+       }
+       return cm.targetAddr
+}
+
+// tlsHost returns the host name to match against the peer's
+// TLS certificate.
+func (cm *connectMethod) tlsHost() string {
+       h := cm.targetAddr
+       if hasPort(h) {
+               h = h[:strings.LastIndex(h, ":")]
+       }
+       return h
+}
+
+type readResult struct {
+       res *Response // either res or err will be set
+       err os.Error
+}
+
+type writeRequest struct {
+       // Set by client (in pc.roundTrip)
+       req   *Request
+       resch chan *readResult
+
+       // Set by writeLoop if an error writing headers.
+       writeErr os.Error
+}
+
+// persistConn wraps a connection, usually a persistent one
+// (but may be used for non-keep-alive requests as well)
+type persistConn struct {
+       t                 *Transport
+       cacheKey          string // its connectMethod.String()
+       conn              net.Conn
+       cc                *ClientConn
+       br                *bufio.Reader
+       reqch             chan requestAndChan // written by roundTrip(); read by readLoop()
+       mutateRequestFunc func(*Request)      // nil or func to modify each outbound request
+
+       lk                   sync.Mutex // guards numExpectedResponses and broken
+       numExpectedResponses int
+       broken               bool // an error has happened on this connection; marked broken so it's not reused.
+}
+
+func (pc *persistConn) isBroken() bool {
+       pc.lk.Lock()
+       defer pc.lk.Unlock()
+       return pc.broken
+}
+
+func (pc *persistConn) expectingResponse() bool {
+       pc.lk.Lock()
+       defer pc.lk.Unlock()
+       return pc.numExpectedResponses > 0
+}
+
+func (pc *persistConn) readLoop() {
+       alive := true
+       for alive {
+               pb, err := pc.br.Peek(1)
+               if err != nil {
+                       if (err == os.EOF || err == os.EINVAL) && !pc.expectingResponse() {
+                               // Remote side closed on us.  (We probably hit their
+                               // max idle timeout)
+                               pc.close()
+                               return
+                       }
+               }
+               if !pc.expectingResponse() {
+                       log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v",
+                               string(pb), err)
+                       pc.close()
+                       return
+               }
+
+               rc := <-pc.reqch
+               resp, err := pc.cc.Read(rc.req)
+               if err == nil && !rc.req.Close {
+                       pc.t.putIdleConn(pc)
+               }
+               if err == ErrPersistEOF {
+                       // Succeeded, but we can't send any more
+                       // persistent connections on this again.  We
+                       // hide this error to upstream callers.
+                       alive = false
+                       err = nil
+               } else if err != nil {
+                       alive = false
+               }
+               rc.ch <- responseAndError{resp, err}
+
+               // Wait for the just-returned response body to be fully consumed
+               // before we race and peek on the underlying bufio reader.
+               if alive {
+                       <-resp.Body.(*bodyEOFSignal).ch
+               }
+       }
+}
+
+type responseAndError struct {
+       res *Response
+       err os.Error
+}
+
+type requestAndChan struct {
+       req *Request
+       ch  chan responseAndError
+}
+
+func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
+       if pc.mutateRequestFunc != nil {
+               pc.mutateRequestFunc(req)
        }
 
-       reader := bufio.NewReader(conn)
-       resp, err = ReadResponse(reader, req.Method)
+       pc.lk.Lock()
+       pc.numExpectedResponses++
+       pc.lk.Unlock()
+
+       err = pc.cc.Write(req)
        if err != nil {
-               conn.Close()
-               return nil, err
+               pc.close()
+               return
        }
 
-       resp.Body = readClose{resp.Body, conn}
+       ch := make(chan responseAndError, 1)
+       pc.reqch <- requestAndChan{req, ch}
+       re := <-ch
+       pc.lk.Lock()
+       pc.numExpectedResponses--
+       pc.lk.Unlock()
+       return re.res, re.err
+}
+
+func (pc *persistConn) close() {
+       pc.lk.Lock()
+       defer pc.lk.Unlock()
+       pc.broken = true
+       pc.cc.Close()
+       pc.conn.Close()
+       pc.mutateRequestFunc = nil
+}
+
+var portMap = map[string]string{
+       "http":  "80",
+       "https": "443",
+}
+
+// canonicalAddr returns url.Host but always with a ":port" suffix
+func canonicalAddr(url *URL) string {
+       addr := url.Host
+       if !hasPort(addr) {
+               return addr + ":" + portMap[url.Scheme]
+       }
+       return addr
+}
+
+func responseIsKeepAlive(res *Response) bool {
+       // TODO: implement.  for now just always shutting down the connection.
+       return false
+}
+
+// readResponseWithEOFSignal is a wrapper around ReadResponse that replaces
+// the response body with a bodyEOFSignal-wrapped version.
+func readResponseWithEOFSignal(r *bufio.Reader, requestMethod string) (resp *Response, err os.Error) {
+       resp, err = ReadResponse(r, requestMethod)
+       if err == nil {
+               resp.Body = &bodyEOFSignal{resp.Body, make(chan bool, 1), false}
+       }
+       return
+}
+
+// bodyEOFSignal wraps a ReadCloser but sends on ch once once
+// the wrapped ReadCloser is fully consumed (including on Close)
+type bodyEOFSignal struct {
+       body io.ReadCloser
+       ch   chan bool
+       done bool
+}
+
+func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) {
+       n, err = es.body.Read(p)
+       if err == os.EOF && !es.done {
+               es.ch <- true
+               es.done = true
+       }
+       return
+}
+
+func (es *bodyEOFSignal) Close() (err os.Error) {
+       err = es.body.Close()
+       if err == nil && !es.done {
+               es.ch <- true
+               es.done = true
+       }
        return
 }
diff --git a/libgo/go/http/transport_test.go b/libgo/go/http/transport_test.go
new file mode 100644 (file)
index 0000000..5c3e1cd
--- /dev/null
@@ -0,0 +1,235 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests for transport.go
+
+package http_test
+
+import (
+       "fmt"
+       . "http"
+       "http/httptest"
+       "io/ioutil"
+       "os"
+       "testing"
+       "time"
+)
+
+// TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close
+//       and then verify that the final 2 responses get errors back.
+
+// hostPortHandler writes back the client's "host:port".
+var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
+       if r.FormValue("close") == "true" {
+               w.Header().Set("Connection", "close")
+       }
+       fmt.Fprintf(w, "%s", r.RemoteAddr)
+})
+
+// Two subsequent requests and verify their response is the same.
+// The response from the server is our own IP:port
+func TestTransportKeepAlives(t *testing.T) {
+       ts := httptest.NewServer(hostPortHandler)
+       defer ts.Close()
+
+       for _, disableKeepAlive := range []bool{false, true} {
+               tr := &Transport{DisableKeepAlives: disableKeepAlive}
+               c := &Client{Transport: tr}
+
+               fetch := func(n int) string {
+                       res, _, err := c.Get(ts.URL)
+                       if err != nil {
+                               t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
+                       }
+                       body, err := ioutil.ReadAll(res.Body)
+                       if err != nil {
+                               t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
+                       }
+                       return string(body)
+               }
+
+               body1 := fetch(1)
+               body2 := fetch(2)
+
+               bodiesDiffer := body1 != body2
+               if bodiesDiffer != disableKeepAlive {
+                       t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
+                               disableKeepAlive, bodiesDiffer, body1, body2)
+               }
+       }
+}
+
+func TestTransportConnectionCloseOnResponse(t *testing.T) {
+       ts := httptest.NewServer(hostPortHandler)
+       defer ts.Close()
+
+       for _, connectionClose := range []bool{false, true} {
+               tr := &Transport{}
+               c := &Client{Transport: tr}
+
+               fetch := func(n int) string {
+                       req := new(Request)
+                       var err os.Error
+                       req.URL, err = ParseURL(ts.URL + fmt.Sprintf("?close=%v", connectionClose))
+                       if err != nil {
+                               t.Fatalf("URL parse error: %v", err)
+                       }
+                       req.Method = "GET"
+                       req.Proto = "HTTP/1.1"
+                       req.ProtoMajor = 1
+                       req.ProtoMinor = 1
+
+                       res, err := c.Do(req)
+                       if err != nil {
+                               t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
+                       }
+                       body, err := ioutil.ReadAll(res.Body)
+                       if err != nil {
+                               t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
+                       }
+                       return string(body)
+               }
+
+               body1 := fetch(1)
+               body2 := fetch(2)
+               bodiesDiffer := body1 != body2
+               if bodiesDiffer != connectionClose {
+                       t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
+                               connectionClose, bodiesDiffer, body1, body2)
+               }
+       }
+}
+
+func TestTransportConnectionCloseOnRequest(t *testing.T) {
+       ts := httptest.NewServer(hostPortHandler)
+       defer ts.Close()
+
+       for _, connectionClose := range []bool{false, true} {
+               tr := &Transport{}
+               c := &Client{Transport: tr}
+
+               fetch := func(n int) string {
+                       req := new(Request)
+                       var err os.Error
+                       req.URL, err = ParseURL(ts.URL)
+                       if err != nil {
+                               t.Fatalf("URL parse error: %v", err)
+                       }
+                       req.Method = "GET"
+                       req.Proto = "HTTP/1.1"
+                       req.ProtoMajor = 1
+                       req.ProtoMinor = 1
+                       req.Close = connectionClose
+
+                       res, err := c.Do(req)
+                       if err != nil {
+                               t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
+                       }
+                       body, err := ioutil.ReadAll(res.Body)
+                       if err != nil {
+                               t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
+                       }
+                       return string(body)
+               }
+
+               body1 := fetch(1)
+               body2 := fetch(2)
+               bodiesDiffer := body1 != body2
+               if bodiesDiffer != connectionClose {
+                       t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
+                               connectionClose, bodiesDiffer, body1, body2)
+               }
+       }
+}
+
+func TestTransportIdleCacheKeys(t *testing.T) {
+       ts := httptest.NewServer(hostPortHandler)
+       defer ts.Close()
+
+       tr := &Transport{DisableKeepAlives: false}
+       c := &Client{Transport: tr}
+
+       if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
+               t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
+       }
+
+       if _, _, err := c.Get(ts.URL); err != nil {
+               t.Error(err)
+       }
+
+       keys := tr.IdleConnKeysForTesting()
+       if e, g := 1, len(keys); e != g {
+               t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
+       }
+
+       if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
+               t.Logf("Expected idle cache key %q; got %q", e, keys[0])
+       }
+
+       tr.CloseIdleConnections()
+       if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
+               t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
+       }
+}
+
+func TestTransportServerClosingUnexpectedly(t *testing.T) {
+       ts := httptest.NewServer(hostPortHandler)
+       defer ts.Close()
+
+       tr := &Transport{}
+       c := &Client{Transport: tr}
+
+       fetch := func(n int) string {
+               res, _, err := c.Get(ts.URL)
+               if err != nil {
+                       t.Fatalf("error in req #%d, GET: %v", n, err)
+               }
+               body, err := ioutil.ReadAll(res.Body)
+               if err != nil {
+                       t.Fatalf("error in req #%d, ReadAll: %v", n, err)
+               }
+               res.Body.Close()
+               return string(body)
+       }
+
+       body1 := fetch(1)
+       body2 := fetch(2)
+
+       ts.CloseClientConnections() // surprise!
+       time.Sleep(25e6)            // idle for a bit (test is inherently racey, but expectedly)
+
+       body3 := fetch(3)
+
+       if body1 != body2 {
+               t.Errorf("expected body1 and body2 to be equal")
+       }
+       if body2 == body3 {
+               t.Errorf("expected body2 and body3 to be different")
+       }
+}
+
+func TestTransportNilURL(t *testing.T) {
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               fmt.Fprintf(w, "Hi")
+       }))
+       defer ts.Close()
+
+       req := new(Request)
+       req.URL = nil // what we're actually testing
+       req.Method = "GET"
+       req.RawURL = ts.URL
+       req.Proto = "HTTP/1.1"
+       req.ProtoMajor = 1
+       req.ProtoMinor = 1
+
+       tr := &Transport{}
+       res, err := tr.RoundTrip(req)
+       if err != nil {
+               t.Fatalf("unexpected RoundTrip error: %v", err)
+       }
+       body, err := ioutil.ReadAll(res.Body)
+       if g, e := string(body), "Hi"; g != e {
+               t.Fatalf("Expected response body of %q; got %q", e, g)
+       }
+}
index fb3fdcda1e673b6da62ad5de96e5e13cbc9b95b5..ed6c310eb44c7e57e633b670cfcc428d876844d9 100644 (file)
@@ -13,11 +13,17 @@ import (
        "sort"
 )
 
+// readAll reads from r until an error or EOF and returns the data it read
+// from the internal buffer allocated with a specified capacity.
+func readAll(r io.Reader, capacity int64) ([]byte, os.Error) {
+       buf := bytes.NewBuffer(make([]byte, 0, capacity))
+       _, err := buf.ReadFrom(r)
+       return buf.Bytes(), err
+}
+
 // ReadAll reads from r until an error or EOF and returns the data it read.
 func ReadAll(r io.Reader) ([]byte, os.Error) {
-       var buf bytes.Buffer
-       _, err := io.Copy(&buf, r)
-       return buf.Bytes(), err
+       return readAll(r, bytes.MinRead)
 }
 
 // ReadFile reads the file named by filename and returns the contents.
@@ -34,16 +40,12 @@ func ReadFile(filename string) ([]byte, os.Error) {
        if err == nil && fi.Size < 2e9 { // Don't preallocate a huge buffer, just in case.
                n = fi.Size
        }
-       // Add a little extra in case Size is zero, and to avoid another allocation after
-       // Read has filled the buffer.
-       n += bytes.MinRead
-       // Pre-allocate the correct size of buffer, then set its size to zero.  The
-       // Buffer will read into the allocated space cheaply.  If the size was wrong,
-       // we'll either waste some space off the end or reallocate as needed, but
+       // As initial capacity for readAll, use n + a little extra in case Size is zero,
+       // and to avoid another allocation after Read has filled the buffer.  The readAll
+       // call will read into its allocated internal buffer cheaply.  If the size was
+       // wrong, we'll either waste some space off the end or reallocate as needed, but
        // in the overwhelmingly common case we'll get it just right.
-       buf := bytes.NewBuffer(make([]byte, 0, n))
-       _, err = buf.ReadFrom(f)
-       return buf.Bytes(), err
+       return readAll(f, n+bytes.MinRead)
 }
 
 // WriteFile writes data to a file named by filename.
@@ -88,3 +90,15 @@ func ReadDir(dirname string) ([]*os.FileInfo, os.Error) {
        sort.Sort(fi)
        return fi, nil
 }
+
+type nopCloser struct {
+       io.Reader
+}
+
+func (nopCloser) Close() os.Error { return nil }
+
+// NopCloser returns a ReadCloser with a no-op Close method wrapping
+// the provided Reader r.
+func NopCloser(r io.Reader) io.ReadCloser {
+       return nopCloser{r}
+}
index c7cc67b1b7462a25c642bfd6db242065ff60e73c..62f8849c0a0b74c964fe59a513c16fac3107293c 100644 (file)
@@ -6,6 +6,7 @@ package ioutil
 
 import (
        "os"
+       "path/filepath"
        "strconv"
 )
 
@@ -46,8 +47,7 @@ func TempFile(dir, prefix string) (f *os.File, err os.Error) {
 
        nconflict := 0
        for i := 0; i < 10000; i++ {
-               // TODO(rsc): use filepath.Join
-               name := dir + "/" + prefix + nextSuffix()
+               name := filepath.Join(dir, prefix+nextSuffix())
                f, err = os.Open(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0600)
                if pe, ok := err.(*os.PathError); ok && pe.Error == os.EEXIST {
                        if nconflict++; nconflict > 10 {
@@ -74,8 +74,7 @@ func TempDir(dir, prefix string) (name string, err os.Error) {
 
        nconflict := 0
        for i := 0; i < 10000; i++ {
-               // TODO(rsc): use filepath.Join
-               try := dir + "/" + prefix + nextSuffix()
+               try := filepath.Join(dir, prefix+nextSuffix())
                err = os.Mkdir(try, 0700)
                if pe, ok := err.(*os.PathError); ok && pe.Error == os.EEXIST {
                        if nconflict++; nconflict > 10 {
index 6013ec1d4a42aa87ba5afe640a39550a917b992c..80c62f672c1bd7f9380719700f97afc4ca2bef91 100644 (file)
@@ -7,6 +7,7 @@ package ioutil_test
 import (
        . "io/ioutil"
        "os"
+       "path/filepath"
        "regexp"
        "testing"
 )
@@ -25,7 +26,7 @@ func TestTempFile(t *testing.T) {
        if f != nil {
                f.Close()
                os.Remove(f.Name())
-               re := regexp.MustCompile("^" + regexp.QuoteMeta(dir) + "/ioutil_test[0-9]+$")
+               re := regexp.MustCompile("^" + regexp.QuoteMeta(filepath.Join(dir, "ioutil_test")) + "[0-9]+$")
                if !re.MatchString(f.Name()) {
                        t.Errorf("TempFile(`"+dir+"`, `ioutil_test`) created bad name %s", f.Name())
                }
@@ -45,7 +46,7 @@ func TestTempDir(t *testing.T) {
        }
        if name != "" {
                os.Remove(name)
-               re := regexp.MustCompile("^" + regexp.QuoteMeta(dir) + "/ioutil_test[0-9]+$")
+               re := regexp.MustCompile("^" + regexp.QuoteMeta(filepath.Join(dir, "ioutil_test")) + "[0-9]+$")
                if !re.MatchString(name) {
                        t.Errorf("TempDir(`"+dir+"`, `ioutil_test`) created bad name %s", name)
                }
index df76418b93da55d722fbd3090c14efcf3eadb545..00be8efa2e2058908ad29dc0171fc4d5af6f4536 100644 (file)
@@ -9,7 +9,6 @@ package io
 
 import (
        "os"
-       "runtime"
        "sync"
 )
 
@@ -18,208 +17,114 @@ type pipeResult struct {
        err os.Error
 }
 
-// Shared pipe structure.
+// A pipe is the shared pipe structure underlying PipeReader and PipeWriter.
 type pipe struct {
-       // Reader sends on cr1, receives on cr2.
-       // Writer does the same on cw1, cw2.
-       r1, w1 chan []byte
-       r2, w2 chan pipeResult
-
-       rclose chan os.Error // read close; error to return to writers
-       wclose chan os.Error // write close; error to return to readers
-
-       done chan int // read or write half is done
-}
-
-func (p *pipe) run() {
-       var (
-               rb    []byte      // pending Read
-               wb    []byte      // pending Write
-               wn    int         // amount written so far from wb
-               rerr  os.Error    // if read end is closed, error to send to writers
-               werr  os.Error    // if write end is closed, error to send to readers
-               r1    chan []byte // p.cr1 or nil depending on whether Read is ok
-               w1    chan []byte // p.cw1 or nil depending on whether Write is ok
-               ndone int
-       )
-
-       // Read and Write are enabled at the start.
-       r1 = p.r1
-       w1 = p.w1
-
+       rl    sync.Mutex // gates readers one at a time
+       wl    sync.Mutex // gates writers one at a time
+       l     sync.Mutex // protects remaining fields
+       data  []byte     // data remaining in pending write
+       rwait sync.Cond  // waiting reader
+       wwait sync.Cond  // waiting writer
+       rerr  os.Error   // if reader closed, error to give writes
+       werr  os.Error   // if writer closed, error to give reads
+}
+
+func (p *pipe) read(b []byte) (n int, err os.Error) {
+       // One reader at a time.
+       p.rl.Lock()
+       defer p.rl.Unlock()
+
+       p.l.Lock()
+       defer p.l.Unlock()
        for {
-               select {
-               case <-p.done:
-                       if ndone++; ndone == 2 {
-                               // both reader and writer are gone
-                               // close out any existing i/o
-                               if r1 == nil {
-                                       p.r2 <- pipeResult{0, os.EINVAL}
-                               }
-                               if w1 == nil {
-                                       p.w2 <- pipeResult{0, os.EINVAL}
-                               }
-                               return
-                       }
-                       continue
-               case rerr = <-p.rclose:
-                       if w1 == nil {
-                               // finish pending Write
-                               p.w2 <- pipeResult{wn, rerr}
-                               wn = 0
-                               w1 = p.w1 // allow another Write
-                       }
-                       if r1 == nil {
-                               // Close of read side during Read.
-                               // finish pending Read with os.EINVAL.
-                               p.r2 <- pipeResult{0, os.EINVAL}
-                               r1 = p.r1 // allow another Read
-                       }
-                       continue
-               case werr = <-p.wclose:
-                       if r1 == nil {
-                               // finish pending Read
-                               p.r2 <- pipeResult{0, werr}
-                               r1 = p.r1 // allow another Read
-                       }
-                       if w1 == nil {
-                               // Close of write side during Write.
-                               // finish pending Write with os.EINVAL.
-                               p.w2 <- pipeResult{wn, os.EINVAL}
-                               wn = 0
-                               w1 = p.w1 // allow another Write
-                       }
-                       continue
-               case rb = <-r1:
-                       if werr != nil {
-                               // write end is closed
-                               p.r2 <- pipeResult{0, werr}
-                               continue
-                       }
-                       if rerr != nil {
-                               // read end is closed
-                               p.r2 <- pipeResult{0, os.EINVAL}
-                               continue
-                       }
-                       r1 = nil // disable Read until this one is done
-               case wb = <-w1:
-                       if rerr != nil {
-                               // read end is closed
-                               p.w2 <- pipeResult{0, rerr}
-                               continue
-                       }
-                       if werr != nil {
-                               // write end is closed
-                               p.w2 <- pipeResult{0, os.EINVAL}
-                               continue
-                       }
-                       w1 = nil // disable Write until this one is done
+               if p.rerr != nil {
+                       return 0, os.EINVAL
                }
-
-               if r1 == nil && w1 == nil {
-                       // Have rb and wb.  Execute.
-                       n := copy(rb, wb)
-                       wn += n
-                       wb = wb[n:]
-
-                       // Finish Read.
-                       p.r2 <- pipeResult{n, nil}
-                       r1 = p.r1 // allow another Read
-
-                       // Maybe finish Write.
-                       if len(wb) == 0 {
-                               p.w2 <- pipeResult{wn, nil}
-                               wn = 0
-                               w1 = p.w1 // allow another Write
-                       }
+               if p.data != nil {
+                       break
                }
+               if p.werr != nil {
+                       return 0, p.werr
+               }
+               p.rwait.Wait()
+       }
+       n = copy(b, p.data)
+       p.data = p.data[n:]
+       if len(p.data) == 0 {
+               p.data = nil
+               p.wwait.Signal()
        }
+       return
 }
 
-// Read/write halves of the pipe.
-// They are separate structures for two reasons:
-//  1.  If one end becomes garbage without being Closed,
-//      its finalizer can Close so that the other end
-//      does not hang indefinitely.
-//  2.  Clients cannot use interface conversions on the
-//      read end to find the Write method, and vice versa.
+var zero [0]byte
 
-type pipeHalf struct {
-       c1     chan []byte
-       c2     chan pipeResult
-       cclose chan os.Error
-       done   chan int
-
-       lock   sync.Mutex
-       closed bool
+func (p *pipe) write(b []byte) (n int, err os.Error) {
+       // pipe uses nil to mean not available
+       if b == nil {
+               b = zero[:]
+       }
 
-       io       sync.Mutex
-       ioclosed bool
-}
+       // One writer at a time.
+       p.wl.Lock()
+       defer p.wl.Unlock()
 
-func (p *pipeHalf) rw(data []byte) (n int, err os.Error) {
-       // Run i/o operation.
-       // Check ioclosed flag under lock to make sure we're still allowed to do i/o.
-       p.io.Lock()
-       if p.ioclosed {
-               p.io.Unlock()
-               return 0, os.EINVAL
+       p.l.Lock()
+       defer p.l.Unlock()
+       p.data = b
+       p.rwait.Signal()
+       for {
+               if p.data == nil {
+                       break
+               }
+               if p.rerr != nil {
+                       err = p.rerr
+                       break
+               }
+               if p.werr != nil {
+                       err = os.EINVAL
+               }
+               p.wwait.Wait()
        }
-       p.io.Unlock()
-       p.c1 <- data
-       res := <-p.c2
-       return res.n, res.err
+       n = len(b) - len(p.data)
+       p.data = nil // in case of rerr or werr
+       return
 }
 
-func (p *pipeHalf) close(err os.Error) os.Error {
-       // Close pipe half.
-       // Only first call to close does anything.
-       p.lock.Lock()
-       if p.closed {
-               p.lock.Unlock()
-               return os.EINVAL
+func (p *pipe) rclose(err os.Error) {
+       if err == nil {
+               err = os.EPIPE
        }
-       p.closed = true
-       p.lock.Unlock()
-
-       // First, send the close notification.
-       p.cclose <- err
-
-       // Runner is now responding to rw operations
-       // with os.EINVAL.  Cut off future rw operations
-       // by setting ioclosed flag.
-       p.io.Lock()
-       p.ioclosed = true
-       p.io.Unlock()
-
-       // With ioclosed set, there will be no more rw operations
-       // working on the channels.
-       // Tell the runner we won't be bothering it anymore.
-       p.done <- 1
-
-       // Successfully torn down; can disable finalizer.
-       runtime.SetFinalizer(p, nil)
-
-       return nil
+       p.l.Lock()
+       defer p.l.Unlock()
+       p.rerr = err
+       p.rwait.Signal()
+       p.wwait.Signal()
 }
 
-func (p *pipeHalf) finalizer() {
-       p.close(os.EINVAL)
+func (p *pipe) wclose(err os.Error) {
+       if err == nil {
+               err = os.EOF
+       }
+       p.l.Lock()
+       defer p.l.Unlock()
+       p.werr = err
+       p.rwait.Signal()
+       p.wwait.Signal()
 }
 
-
 // A PipeReader is the read half of a pipe.
 type PipeReader struct {
-       pipeHalf
+       p *pipe
 }
 
 // Read implements the standard Read interface:
 // it reads data from the pipe, blocking until a writer
 // arrives or the write end is closed.
 // If the write end is closed with an error, that error is
-// returned as err; otherwise err is nil.
+// returned as err; otherwise err is os.EOF.
 func (r *PipeReader) Read(data []byte) (n int, err os.Error) {
-       return r.rw(data)
+       return r.p.read(data)
 }
 
 // Close closes the reader; subsequent writes to the
@@ -231,15 +136,13 @@ func (r *PipeReader) Close() os.Error {
 // CloseWithError closes the reader; subsequent writes
 // to the write half of the pipe will return the error err.
 func (r *PipeReader) CloseWithError(err os.Error) os.Error {
-       if err == nil {
-               err = os.EPIPE
-       }
-       return r.close(err)
+       r.p.rclose(err)
+       return nil
 }
 
 // A PipeWriter is the write half of a pipe.
 type PipeWriter struct {
-       pipeHalf
+       p *pipe
 }
 
 // Write implements the standard Write interface:
@@ -248,7 +151,7 @@ type PipeWriter struct {
 // If the read end is closed with an error, that err is
 // returned as err; otherwise err is os.EPIPE.
 func (w *PipeWriter) Write(data []byte) (n int, err os.Error) {
-       return w.rw(data)
+       return w.p.write(data)
 }
 
 // Close closes the writer; subsequent reads from the
@@ -260,10 +163,8 @@ func (w *PipeWriter) Close() os.Error {
 // CloseWithError closes the writer; subsequent reads from the
 // read half of the pipe will return no bytes and the error err.
 func (w *PipeWriter) CloseWithError(err os.Error) os.Error {
-       if err == nil {
-               err = os.EOF
-       }
-       return w.close(err)
+       w.p.wclose(err)
+       return nil
 }
 
 // Pipe creates a synchronous in-memory pipe.
@@ -272,34 +173,10 @@ func (w *PipeWriter) CloseWithError(err os.Error) os.Error {
 // Reads on one end are matched with writes on the other,
 // copying data directly between the two; there is no internal buffering.
 func Pipe() (*PipeReader, *PipeWriter) {
-       p := &pipe{
-               r1:     make(chan []byte),
-               r2:     make(chan pipeResult),
-               w1:     make(chan []byte),
-               w2:     make(chan pipeResult),
-               rclose: make(chan os.Error),
-               wclose: make(chan os.Error),
-               done:   make(chan int),
-       }
-       go p.run()
-
-       // NOTE: Cannot use composite literal here:
-       //      pipeHalf{c1: p.cr1, c2: p.cr2, cclose: p.crclose, cdone: p.cdone}
-       // because this implicitly copies the pipeHalf, which copies the inner mutex.
-
-       r := new(PipeReader)
-       r.c1 = p.r1
-       r.c2 = p.r2
-       r.cclose = p.rclose
-       r.done = p.done
-       runtime.SetFinalizer(r, (*PipeReader).finalizer)
-
-       w := new(PipeWriter)
-       w.c1 = p.w1
-       w.c2 = p.w2
-       w.cclose = p.wclose
-       w.done = p.done
-       runtime.SetFinalizer(w, (*PipeWriter).finalizer)
-
+       p := new(pipe)
+       p.rwait.L = &p.l
+       p.wwait.L = &p.l
+       r := &PipeReader{p}
+       w := &PipeWriter{p}
        return r, w
 }
index 1d855c74c9d493c41b348349354ebc46a4c110d4..0a65a447db999ec90527b56b17de7990f960d02e 100644 (file)
@@ -17,6 +17,7 @@ import (
        "bytes"
        "io"
        "mime"
+       "net/textproto"
        "os"
        "regexp"
        "strings"
@@ -40,7 +41,7 @@ type Part struct {
        // The headers of the body, if any, with the keys canonicalized
        // in the same fashion that the Go http.Request headers are.
        // i.e. "foo-bar" changes case to "Foo-Bar"
-       Header map[string]string
+       Header textproto.MIMEHeader
 
        buffer *bytes.Buffer
        mr     *multiReader
@@ -51,8 +52,8 @@ type Part struct {
 func (p *Part) FormName() string {
        // See http://tools.ietf.org/html/rfc2183 section 2 for EBNF
        // of Content-Disposition value format.
-       v, ok := p.Header["Content-Disposition"]
-       if !ok {
+       v := p.Header.Get("Content-Disposition")
+       if v == "" {
                return ""
        }
        d, params := mime.ParseMediaType(v)
@@ -85,7 +86,7 @@ var devNull = devNullWriter(false)
 
 func newPart(mr *multiReader) (bp *Part, err os.Error) {
        bp = new(Part)
-       bp.Header = make(map[string]string)
+       bp.Header = make(map[string][]string)
        bp.mr = mr
        bp.buffer = new(bytes.Buffer)
        if err = bp.populateHeaders(); err != nil {
@@ -104,10 +105,7 @@ func (bp *Part) populateHeaders() os.Error {
                        return nil
                }
                if matches := headerRegexp.FindStringSubmatch(line); len(matches) == 3 {
-                       key := matches[1]
-                       value := matches[2]
-                       // TODO: canonicalize headers ala http.Request.Header?
-                       bp.Header[key] = value
+                       bp.Header.Add(matches[1], matches[2])
                        continue
                }
                return os.NewError("Unexpected header line found parsing multipart body")
index 7e1ed133ecfa3f7b5bc193c4309a78a4553765c8..1f3d32d7ed6e91bd20698e21fac4b8b3a7988ae4 100644 (file)
@@ -58,7 +58,7 @@ func expectEq(t *testing.T, expected, actual, what string) {
 
 func TestFormName(t *testing.T) {
        p := new(Part)
-       p.Header = make(map[string]string)
+       p.Header = make(map[string][]string)
        tests := [...][2]string{
                {`form-data; name="foo"`, "foo"},
                {` form-data ; name=foo`, "foo"},
@@ -69,7 +69,7 @@ func TestFormName(t *testing.T) {
                {` FORM-DATA ; filename="foo.txt"; name=foo; baz=quux`, "foo"},
        }
        for _, test := range tests {
-               p.Header["Content-Disposition"] = test[0]
+               p.Header.Set("Content-Disposition", test[0])
                expected := test[1]
                actual := p.FormName()
                if actual != expected {
@@ -114,12 +114,15 @@ never read data
                t.Error("Expected part1")
                return
        }
-       if part.Header["Header1"] != "value1" {
+       if part.Header.Get("Header1") != "value1" {
                t.Error("Expected Header1: value")
        }
-       if part.Header["foo-bar"] != "baz" {
+       if part.Header.Get("foo-bar") != "baz" {
                t.Error("Expected foo-bar: baz")
        }
+       if part.Header.Get("Foo-Bar") != "baz" {
+               t.Error("Expected Foo-Bar: baz")
+       }
        buf.Reset()
        io.Copy(buf, part)
        expectEq(t, "My value\r\nThe end.",
@@ -131,7 +134,7 @@ never read data
                t.Error("Expected part2")
                return
        }
-       if part.Header["foo-bar"] != "bazb" {
+       if part.Header.Get("foo-bar") != "bazb" {
                t.Error("Expected foo-bar: bazb")
        }
        buf.Reset()
index d48aefe2cdd8031e459efb236ad4816037e39a6b..7acee149e1f760936b6e850ac7c9a2581d2b6f97 100644 (file)
@@ -2,8 +2,6 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// TODO(rsc): All the prints in this file should go to standard error.
-
 package net
 
 import (
@@ -85,11 +83,12 @@ func (e *InvalidConnError) Timeout() bool   { return false }
 // will the fd be closed.
 
 type pollServer struct {
-       cr, cw   chan *netFD // buffered >= 1
-       pr, pw   *os.File
-       pending  map[int]*netFD
-       poll     *pollster // low-level OS hooks
-       deadline int64     // next deadline (nsec since 1970)
+       cr, cw     chan *netFD // buffered >= 1
+       pr, pw     *os.File
+       poll       *pollster // low-level OS hooks
+       sync.Mutex           // controls pending and deadline
+       pending    map[int]*netFD
+       deadline   int64 // next deadline (nsec since 1970)
 }
 
 func (s *pollServer) AddFD(fd *netFD, mode int) {
@@ -103,10 +102,8 @@ func (s *pollServer) AddFD(fd *netFD, mode int) {
                }
                return
        }
-       if err := s.poll.AddFD(intfd, mode, false); err != nil {
-               panic("pollServer AddFD " + err.String())
-               return
-       }
+
+       s.Lock()
 
        var t int64
        key := intfd << 1
@@ -119,11 +116,27 @@ func (s *pollServer) AddFD(fd *netFD, mode int) {
                t = fd.wdeadline
        }
        s.pending[key] = fd
+       doWakeup := false
        if t > 0 && (s.deadline == 0 || t < s.deadline) {
                s.deadline = t
+               doWakeup = true
+       }
+
+       if err := s.poll.AddFD(intfd, mode, false); err != nil {
+               panic("pollServer AddFD " + err.String())
+       }
+
+       s.Unlock()
+
+       if doWakeup {
+               s.Wakeup()
        }
 }
 
+var wakeupbuf [1]byte
+
+func (s *pollServer) Wakeup() { s.pw.Write(wakeupbuf[0:]) }
+
 func (s *pollServer) LookupFD(fd int, mode int) *netFD {
        key := fd << 1
        if mode == 'w' {
@@ -195,6 +208,8 @@ func (s *pollServer) CheckDeadlines() {
 
 func (s *pollServer) Run() {
        var scratch [100]byte
+       s.Lock()
+       defer s.Unlock()
        for {
                var t = s.deadline
                if t > 0 {
@@ -204,7 +219,7 @@ func (s *pollServer) Run() {
                                continue
                        }
                }
-               fd, mode, err := s.poll.WaitFD(t)
+               fd, mode, err := s.poll.WaitFD(s, t)
                if err != nil {
                        print("pollServer WaitFD: ", err.String(), "\n")
                        return
@@ -215,22 +230,11 @@ func (s *pollServer) Run() {
                        continue
                }
                if fd == s.pr.Fd() {
-                       // Drain our wakeup pipe.
-                       for nn, _ := s.pr.Read(scratch[0:]); nn > 0; {
-                               nn, _ = s.pr.Read(scratch[0:])
-                       }
-                       // Read from channels
-               Update:
-                       for {
-                               select {
-                               case fd := <-s.cr:
-                                       s.AddFD(fd, 'r')
-                               case fd := <-s.cw:
-                                       s.AddFD(fd, 'w')
-                               default:
-                                       break Update
-                               }
-                       }
+                       // Drain our wakeup pipe (we could loop here,
+                       // but it's unlikely that there are more than
+                       // len(scratch) wakeup calls).
+                       s.pr.Read(scratch[0:])
+                       s.CheckDeadlines()
                } else {
                        netfd := s.LookupFD(fd, mode)
                        if netfd == nil {
@@ -242,19 +246,13 @@ func (s *pollServer) Run() {
        }
 }
 
-var wakeupbuf [1]byte
-
-func (s *pollServer) Wakeup() { s.pw.Write(wakeupbuf[0:]) }
-
 func (s *pollServer) WaitRead(fd *netFD) {
-       s.cr <- fd
-       s.Wakeup()
+       s.AddFD(fd, 'r')
        <-fd.cr
 }
 
 func (s *pollServer) WaitWrite(fd *netFD) {
-       s.cw <- fd
-       s.Wakeup()
+       s.AddFD(fd, 'w')
        <-fd.cw
 }
 
index ef86cb17f31eedaee70383593c23168e0bfac0d8..69fbc02c0c100a775919d4afa99853622d2c0a34 100644 (file)
@@ -20,7 +20,17 @@ type pollster struct {
        epfd int
 
        // Events we're already waiting for
+       // Must hold pollServer lock
        events map[int]uint32
+
+       // An event buffer for EpollWait.
+       // Used without a lock, may only be used by WaitFD.
+       waitEventBuf [10]syscall.EpollEvent
+       waitEvents   []syscall.EpollEvent
+
+       // An event buffer for EpollCtl, to avoid a malloc.
+       // Must hold pollServer lock.
+       ctlEvent syscall.EpollEvent
 }
 
 func newpollster() (p *pollster, err os.Error) {
@@ -29,7 +39,7 @@ func newpollster() (p *pollster, err os.Error) {
 
        // The arg to epoll_create is a hint to the kernel
        // about the number of FDs we will care about.
-       // We don't know.
+       // We don't know, and since 2.6.8 the kernel ignores it anyhow.
        if p.epfd, e = syscall.EpollCreate(16); e != 0 {
                return nil, os.NewSyscallError("epoll_create", e)
        }
@@ -38,17 +48,18 @@ func newpollster() (p *pollster, err os.Error) {
 }
 
 func (p *pollster) AddFD(fd int, mode int, repeat bool) os.Error {
-       var ev syscall.EpollEvent
+       // pollServer is locked.
+
        var already bool
-       ev.Fd = int32(fd)
-       ev.Events, already = p.events[fd]
+       p.ctlEvent.Fd = int32(fd)
+       p.ctlEvent.Events, already = p.events[fd]
        if !repeat {
-               ev.Events |= syscall.EPOLLONESHOT
+               p.ctlEvent.Events |= syscall.EPOLLONESHOT
        }
        if mode == 'r' {
-               ev.Events |= readFlags
+               p.ctlEvent.Events |= readFlags
        } else {
-               ev.Events |= writeFlags
+               p.ctlEvent.Events |= writeFlags
        }
 
        var op int
@@ -57,14 +68,16 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) os.Error {
        } else {
                op = syscall.EPOLL_CTL_ADD
        }
-       if e := syscall.EpollCtl(p.epfd, op, fd, &ev); e != 0 {
+       if e := syscall.EpollCtl(p.epfd, op, fd, &p.ctlEvent); e != 0 {
                return os.NewSyscallError("epoll_ctl", e)
        }
-       p.events[fd] = ev.Events
+       p.events[fd] = p.ctlEvent.Events
        return nil
 }
 
 func (p *pollster) StopWaiting(fd int, bits uint) {
+       // pollServer is locked.
+
        events, already := p.events[fd]
        if !already {
                print("Epoll unexpected fd=", fd, "\n")
@@ -82,10 +95,9 @@ func (p *pollster) StopWaiting(fd int, bits uint) {
        // event in the kernel.  Otherwise, delete it.
        events &= ^uint32(bits)
        if int32(events)&^syscall.EPOLLONESHOT != 0 {
-               var ev syscall.EpollEvent
-               ev.Fd = int32(fd)
-               ev.Events = events
-               if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &ev); e != 0 {
+               p.ctlEvent.Fd = int32(fd)
+               p.ctlEvent.Events = events
+               if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &p.ctlEvent); e != 0 {
                        print("Epoll modify fd=", fd, ": ", os.Errno(e).String(), "\n")
                }
                p.events[fd] = events
@@ -98,6 +110,8 @@ func (p *pollster) StopWaiting(fd int, bits uint) {
 }
 
 func (p *pollster) DelFD(fd int, mode int) {
+       // pollServer is locked.
+
        if mode == 'r' {
                p.StopWaiting(fd, readFlags)
        } else {
@@ -105,24 +119,32 @@ func (p *pollster) DelFD(fd int, mode int) {
        }
 }
 
-func (p *pollster) WaitFD(nsec int64) (fd int, mode int, err os.Error) {
-       // Get an event.
-       var evarray [1]syscall.EpollEvent
-       ev := &evarray[0]
-       var msec int = -1
-       if nsec > 0 {
-               msec = int((nsec + 1e6 - 1) / 1e6)
-       }
-       n, e := syscall.EpollWait(p.epfd, evarray[0:], msec)
-       for e == syscall.EAGAIN || e == syscall.EINTR {
-               n, e = syscall.EpollWait(p.epfd, evarray[0:], msec)
-       }
-       if e != 0 {
-               return -1, 0, os.NewSyscallError("epoll_wait", e)
-       }
-       if n == 0 {
-               return -1, 0, nil
+func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err os.Error) {
+       for len(p.waitEvents) == 0 {
+               var msec int = -1
+               if nsec > 0 {
+                       msec = int((nsec + 1e6 - 1) / 1e6)
+               }
+
+               s.Unlock()
+               n, e := syscall.EpollWait(p.epfd, p.waitEventBuf[0:], msec)
+               s.Lock()
+
+               if e != 0 {
+                       if e == syscall.EAGAIN || e == syscall.EINTR {
+                               continue
+                       }
+                       return -1, 0, os.NewSyscallError("epoll_wait", e)
+               }
+               if n == 0 {
+                       return -1, 0, nil
+               }
+               p.waitEvents = p.waitEventBuf[0:n]
        }
+
+       ev := &p.waitEvents[0]
+       p.waitEvents = p.waitEvents[1:]
+
        fd = int(ev.Fd)
 
        if ev.Events&writeFlags != 0 {
index e82224a28364274e55fb51b410481b430053ac18..1904af0d6ad784556a8f2cdd1594a1dea532adcf 100644 (file)
@@ -12,6 +12,8 @@
 
 package net
 
+import "os"
+
 // IP address lengths (bytes).
 const (
        IPv4len = 4
@@ -39,11 +41,7 @@ type IPMask []byte
 // IPv4 address a.b.c.d.
 func IPv4(a, b, c, d byte) IP {
        p := make(IP, IPv6len)
-       for i := 0; i < 10; i++ {
-               p[i] = 0
-       }
-       p[10] = 0xff
-       p[11] = 0xff
+       copy(p, v4InV6Prefix)
        p[12] = a
        p[13] = b
        p[14] = c
@@ -51,6 +49,8 @@ func IPv4(a, b, c, d byte) IP {
        return p
 }
 
+var v4InV6Prefix = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}
+
 // IPv4Mask returns the IP mask (in 16-byte form) of the
 // IPv4 mask a.b.c.d.
 func IPv4Mask(a, b, c, d byte) IPMask {
@@ -140,9 +140,24 @@ func (ip IP) DefaultMask() IPMask {
        return nil // not reached
 }
 
+func allFF(b []byte) bool {
+       for _, c := range b {
+               if c != 0xff {
+                       return false
+               }
+       }
+       return true
+}
+
 // Mask returns the result of masking the IP address ip with mask.
 func (ip IP) Mask(mask IPMask) IP {
        n := len(ip)
+       if len(mask) == 16 && len(ip) == 4 && allFF(mask[:12]) {
+               mask = mask[12:]
+       }
+       if len(mask) == 4 && len(ip) == 16 && bytesEqual(ip[:12], v4InV6Prefix) {
+               ip = ip[12:]
+       }
        if n != len(mask) {
                return nil
        }
@@ -245,6 +260,34 @@ func (ip IP) String() string {
        return s
 }
 
+// Equal returns true if ip and x are the same IP address.
+// An IPv4 address and that same address in IPv6 form are
+// considered to be equal.
+func (ip IP) Equal(x IP) bool {
+       if len(ip) == len(x) {
+               return bytesEqual(ip, x)
+       }
+       if len(ip) == 4 && len(x) == 16 {
+               return bytesEqual(x[0:12], v4InV6Prefix) && bytesEqual(ip, x[12:])
+       }
+       if len(ip) == 16 && len(x) == 4 {
+               return bytesEqual(ip[0:12], v4InV6Prefix) && bytesEqual(ip[12:], x)
+       }
+       return false
+}
+
+func bytesEqual(x, y []byte) bool {
+       if len(x) != len(y) {
+               return false
+       }
+       for i, b := range x {
+               if y[i] != b {
+                       return false
+               }
+       }
+       return true
+}
+
 // If mask is a sequence of 1 bits followed by 0 bits,
 // return the number of 1 bits.
 func simpleMaskLength(mask IPMask) int {
@@ -351,7 +394,6 @@ func parseIPv6(s string) IP {
 
        // Loop, parsing hex numbers followed by colon.
        j := 0
-L:
        for j < IPv6len {
                // Hex number.
                n, i1, ok := xtoi(s, i)
@@ -432,15 +474,66 @@ L:
        return p
 }
 
+// A SyntaxError represents a malformed text string and the type of string that was expected.
+type SyntaxError struct {
+       Type string
+       Text string
+}
+
+func (e *SyntaxError) String() string {
+       return "invalid " + e.Type + ": " + e.Text
+}
+
+func parseIP(s string) IP {
+       if p := parseIPv4(s); p != nil {
+               return p
+       }
+       if p := parseIPv6(s); p != nil {
+               return p
+       }
+       return nil
+}
+
 // ParseIP parses s as an IP address, returning the result.
 // The string s can be in dotted decimal ("74.125.19.99")
 // or IPv6 ("2001:4860:0:2001::68") form.
 // If s is not a valid textual representation of an IP address,
 // ParseIP returns nil.
 func ParseIP(s string) IP {
-       p := parseIPv4(s)
-       if p != nil {
+       if p := parseIPv4(s); p != nil {
                return p
        }
        return parseIPv6(s)
 }
+
+// ParseCIDR parses s as a CIDR notation IP address and mask,
+// like "192.168.100.1/24" or "2001:DB8::/48".
+func ParseCIDR(s string) (ip IP, mask IPMask, err os.Error) {
+       i := byteIndex(s, '/')
+       if i < 0 {
+               return nil, nil, &SyntaxError{"CIDR address", s}
+       }
+       ipstr, maskstr := s[:i], s[i+1:]
+       ip = ParseIP(ipstr)
+       nn, i, ok := dtoi(maskstr, 0)
+       if ip == nil || !ok || i != len(maskstr) || nn < 0 || nn > 8*len(ip) {
+               return nil, nil, &SyntaxError{"CIDR address", s}
+       }
+       n := uint(nn)
+       if len(ip) == 4 {
+               v4mask := ^uint32(0xffffffff >> n)
+               mask = IPMask(IPv4(byte(v4mask>>24), byte(v4mask>>16), byte(v4mask>>8), byte(v4mask)))
+               return ip, mask, nil
+       }
+       mask = make(IPMask, 16)
+       for i := 0; i < 16; i++ {
+               if n >= 8 {
+                       mask[i] = 0xff
+                       n -= 8
+                       continue
+               }
+               mask[i] = ^byte(0xff >> n)
+               n = 0
+       }
+       return ip, mask, nil
+}
index dd06050ee50f0f49eed56ade78e4f8f0d126358c..d2cd8efc559753cbc4190e6ee73a07dd4f89022f 100644 (file)
@@ -306,7 +306,7 @@ func (nch *netChan) sender() {
 }
 
 // Receive value from local side for sending to remote side.
-func (nch *netChan) recv() (val reflect.Value, closed bool) {
+func (nch *netChan) recv() (val reflect.Value, ok bool) {
        if nch.dir != Send {
                panic("recv on wrong direction of channel")
        }
@@ -317,7 +317,7 @@ func (nch *netChan) recv() (val reflect.Value, closed bool) {
                nch.space++
        }
        nch.space--
-       return nch.ch.Recv(), nch.ch.Closed()
+       return nch.ch.Recv()
 }
 
 // acked is called when the remote side indicates that
index 55eba0e2e0f61fb5cf5f7780a3000cbc74c1a6da..e91e777e306ee1ac58435ce3d8f6e73940fad5b9 100644 (file)
@@ -181,8 +181,8 @@ func (client *expClient) run() {
 // The header is passed by value to avoid issues of overwriting.
 func (client *expClient) serveRecv(nch *netChan, hdr header, count int64) {
        for {
-               val, closed := nch.recv()
-               if closed {
+               val, ok := nch.recv()
+               if !ok {
                        if err := client.encode(&hdr, payClosed, nil); err != nil {
                                expLog("error encoding server closed message:", err)
                        }
index 30edcd8123bd1b2f159255aad5963fb6a11b9c63..5db679a3ed6d251003ec76941e6dd38a0fd65e9d 100644 (file)
@@ -213,8 +213,8 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, size,
        if dir == Send {
                go func() {
                        for i := 0; n == -1 || i < n; i++ {
-                               val, closed := nch.recv()
-                               if closed {
+                               val, ok := nch.recv()
+                               if !ok {
                                        if err = imp.encode(hdr, payClosed, nil); err != nil {
                                                impLog("error encoding client closed message:", err)
                                        }
index 1c84a9d14dee485b3b2804ecf8b1d4d0ee9758e2..1b5c560872e67e1aaa7a14d23ae5719dc26a6d1f 100644 (file)
@@ -41,8 +41,8 @@ func exportReceive(exp *Exporter, t *testing.T, expDone chan bool) {
                t.Fatal("exportReceive:", err)
        }
        for i := 0; i < count; i++ {
-               v := <-ch
-               if closed(ch) {
+               v, ok := <-ch
+               if !ok {
                        if i != closeCount {
                                t.Errorf("exportReceive expected close at %d; got one at %d", closeCount, i)
                        }
@@ -78,8 +78,8 @@ func importReceive(imp *Importer, t *testing.T, done chan bool) {
                t.Fatal("importReceive:", err)
        }
        for i := 0; i < count; i++ {
-               v := <-ch
-               if closed(ch) {
+               v, ok := <-ch
+               if !ok {
                        if i != closeCount {
                                t.Errorf("importReceive expected close at %d; got one at %d", closeCount, i)
                        }
@@ -212,8 +212,8 @@ func TestExportHangup(t *testing.T) {
        }
        // Now hang up the channel.  Importer should see it close.
        exp.Hangup("exportedSend")
-       v = <-ich
-       if !closed(ich) {
+       v, ok := <-ich
+       if ok {
                t.Fatal("expected channel to be closed; got value", v)
        }
 }
@@ -242,8 +242,8 @@ func TestImportHangup(t *testing.T) {
        }
        // Now hang up the channel.  Exporter should see it close.
        imp.Hangup("exportedRecv")
-       v = <-ech
-       if !closed(ech) {
+       v, ok := <-ech
+       if ok {
                t.Fatal("expected channel to be closed; got value", v)
        }
 }
index dbdfacc5857744f5426ac550e63ece539e70efdf..9d80ccfbed464d65564cec40b65e73318b9268d3 100644 (file)
@@ -21,27 +21,46 @@ func newProcess(pid, handle int) *Process {
        return p
 }
 
-// StartProcess starts a new process with the program, arguments,
-// and environment specified by name, argv, and envv. The fd array specifies the
-// file descriptors to be set up in the new process: fd[0] will be Unix file
-// descriptor 0 (standard input), fd[1] descriptor 1, and so on.  A nil entry
-// will cause the child to have no open file descriptor with that index.
-// If dir is not empty, the child chdirs into the directory before execing the program.
-func StartProcess(name string, argv []string, envv []string, dir string, fd []*File) (p *Process, err Error) {
-       if envv == nil {
-               envv = Environ()
+// ProcAttr holds the attributes that will be applied to a new process
+// started by StartProcess.
+type ProcAttr struct {
+       // If Dir is non-empty, the child changes into the directory before
+       // creating the process.
+       Dir string
+       // If Env is non-nil, it gives the environment variables for the
+       // new process in the form returned by Environ.
+       // If it is nil, the result of Environ will be used.
+       Env []string
+       // Files specifies the open files inherited by the new process.  The
+       // first three entries correspond to standard input, standard output, and
+       // standard error.  An implementation may support additional entries,
+       // depending on the underlying operating system.  A nil entry corresponds
+       // to that file being closed when the process starts.
+       Files []*File
+}
+
+// StartProcess starts a new process with the program, arguments and attributes
+// specified by name, argv and attr.
+func StartProcess(name string, argv []string, attr *ProcAttr) (p *Process, err Error) {
+       sysattr := &syscall.ProcAttr{
+               Dir: attr.Dir,
+               Env: attr.Env,
+       }
+       if sysattr.Env == nil {
+               sysattr.Env = Environ()
        }
        // Create array of integer (system) fds.
-       intfd := make([]int, len(fd))
-       for i, f := range fd {
+       intfd := make([]int, len(attr.Files))
+       for i, f := range attr.Files {
                if f == nil {
                        intfd[i] = -1
                } else {
                        intfd[i] = f.Fd()
                }
        }
+       sysattr.Files = intfd
 
-       pid, h, e := syscall.StartProcess(name, argv, envv, dir, intfd)
+       pid, h, e := syscall.StartProcess(name, argv, sysattr)
        if e != 0 {
                return nil, &PathError{"fork/exec", name, Errno(e)}
        }
index 332edcb644d2eaff463496a32d65b4cd0156c0ba..79c3bfa36e489e8af75fbc29f3fd57033e080058 100644 (file)
@@ -35,6 +35,7 @@ func TestInotifyEvents(t *testing.T) {
        // Receive events on the event channel on a separate goroutine
        eventstream := watcher.Event
        var eventsReceived = 0
+       done := make(chan bool)
        go func() {
                for event := range eventstream {
                        // Only count relevant events
@@ -45,6 +46,7 @@ func TestInotifyEvents(t *testing.T) {
                                t.Logf("unexpected event received: %s", event)
                        }
                }
+               done <- true
        }()
 
        // Create a file
@@ -64,16 +66,12 @@ func TestInotifyEvents(t *testing.T) {
        t.Log("calling Close()")
        watcher.Close()
        t.Log("waiting for the event channel to become closed...")
-       var i = 0
-       for !closed(eventstream) {
-               if i >= 20 {
-                       t.Fatal("event stream was not closed after 1 second, as expected")
-               }
-               t.Log("waiting for 50 ms...")
-               time.Sleep(50e6) // 50 ms
-               i++
+       select {
+       case <-done:
+               t.Log("event channel closed")
+       case <-time.After(1e9):
+               t.Fatal("event stream was not closed after 1 second")
        }
-       t.Log("event channel closed")
 }
 
 
index be5f4824e4fcdc8d58d9791074d881e9daffe050..bb1b8e31893150aa0a7d0774dfc956ce6f63baf2 100644 (file)
@@ -10,6 +10,7 @@ import (
        "io"
        "io/ioutil"
        . "os"
+       "path/filepath"
        "strings"
        "syscall"
        "testing"
@@ -405,25 +406,13 @@ func TestRename(t *testing.T) {
        }
 }
 
-func TestForkExec(t *testing.T) {
-       var cmd, adir, expect string
-       var args []string
+func exec(t *testing.T, dir, cmd string, args []string, expect string) {
        r, w, err := Pipe()
        if err != nil {
                t.Fatalf("Pipe: %v", err)
        }
-       if syscall.OS == "windows" {
-               cmd = Getenv("COMSPEC")
-               args = []string{Getenv("COMSPEC"), "/c cd"}
-               adir = Getenv("SystemRoot")
-               expect = Getenv("SystemRoot") + "\r\n"
-       } else {
-               cmd = "/bin/pwd"
-               args = []string{"pwd"}
-               adir = "/"
-               expect = "/\n"
-       }
-       p, err := StartProcess(cmd, args, nil, adir, []*File{nil, w, Stderr})
+       attr := &ProcAttr{Dir: dir, Files: []*File{nil, w, Stderr}}
+       p, err := StartProcess(cmd, args, attr)
        if err != nil {
                t.Fatalf("StartProcess: %v", err)
        }
@@ -434,12 +423,34 @@ func TestForkExec(t *testing.T) {
        io.Copy(&b, r)
        output := b.String()
        if output != expect {
-               args[0] = cmd
-               t.Errorf("exec %q returned %q wanted %q", strings.Join(args, " "), output, expect)
+               t.Errorf("exec %q returned %q wanted %q",
+                       strings.Join(append([]string{cmd}, args...), " "), output, expect)
        }
        p.Wait(0)
 }
 
+func TestStartProcess(t *testing.T) {
+       var dir, cmd, le string
+       var args []string
+       if syscall.OS == "windows" {
+               le = "\r\n"
+               cmd = Getenv("COMSPEC")
+               dir = Getenv("SystemRoot")
+               args = []string{"/c", "cd"}
+       } else {
+               le = "\n"
+               cmd = "/bin/pwd"
+               dir = "/"
+               args = []string{}
+       }
+       cmddir, cmdbase := filepath.Split(cmd)
+       args = append([]string{cmdbase}, args...)
+       // Test absolute executable path.
+       exec(t, dir, cmd, args, dir+le)
+       // Test relative executable path.
+       exec(t, cmddir, cmdbase, args, filepath.Clean(cmddir)+le)
+}
+
 func checkMode(t *testing.T, path string, mode uint32) {
        dir, err := Stat(path)
        if err != nil {
@@ -747,7 +758,7 @@ func run(t *testing.T, cmd []string) string {
        if err != nil {
                t.Fatal(err)
        }
-       p, err := StartProcess("/bin/hostname", []string{"hostname"}, nil, "/", []*File{nil, w, Stderr})
+       p, err := StartProcess("/bin/hostname", []string{"hostname"}, &ProcAttr{Files: []*File{nil, w, Stderr}})
        if err != nil {
                t.Fatal(err)
        }
diff --git a/libgo/go/path/filepath/match.go b/libgo/go/path/filepath/match.go
new file mode 100644 (file)
index 0000000..ad4053f
--- /dev/null
@@ -0,0 +1,282 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package filepath
+
+import (
+       "os"
+       "sort"
+       "strings"
+       "utf8"
+)
+
+var ErrBadPattern = os.NewError("syntax error in pattern")
+
+// Match returns true if name matches the shell file name pattern.
+// The pattern syntax is:
+//
+//     pattern:
+//             { term }
+//     term:
+//             '*'         matches any sequence of non-Separator characters
+//             '?'         matches any single non-Separator character
+//             '[' [ '^' ] { character-range } ']'
+//                         character class (must be non-empty)
+//             c           matches character c (c != '*', '?', '\\', '[')
+//             '\\' c      matches character c
+//
+//     character-range:
+//             c           matches character c (c != '\\', '-', ']')
+//             '\\' c      matches character c
+//             lo '-' hi   matches character c for lo <= c <= hi
+//
+// Match requires pattern to match all of name, not just a substring.
+// The only possible error return is when pattern is malformed.
+//
+func Match(pattern, name string) (matched bool, err os.Error) {
+Pattern:
+       for len(pattern) > 0 {
+               var star bool
+               var chunk string
+               star, chunk, pattern = scanChunk(pattern)
+               if star && chunk == "" {
+                       // Trailing * matches rest of string unless it has a /.
+                       return strings.Index(name, string(Separator)) < 0, nil
+               }
+               // Look for match at current position.
+               t, ok, err := matchChunk(chunk, name)
+               // if we're the last chunk, make sure we've exhausted the name
+               // otherwise we'll give a false result even if we could still match
+               // using the star
+               if ok && (len(t) == 0 || len(pattern) > 0) {
+                       name = t
+                       continue
+               }
+               if err != nil {
+                       return false, err
+               }
+               if star {
+                       // Look for match skipping i+1 bytes.
+                       // Cannot skip /.
+                       for i := 0; i < len(name) && name[i] != Separator; i++ {
+                               t, ok, err := matchChunk(chunk, name[i+1:])
+                               if ok {
+                                       // if we're the last chunk, make sure we exhausted the name
+                                       if len(pattern) == 0 && len(t) > 0 {
+                                               continue
+                                       }
+                                       name = t
+                                       continue Pattern
+                               }
+                               if err != nil {
+                                       return false, err
+                               }
+                       }
+               }
+               return false, nil
+       }
+       return len(name) == 0, nil
+}
+
+// scanChunk gets the next segment of pattern, which is a non-star string
+// possibly preceded by a star.
+func scanChunk(pattern string) (star bool, chunk, rest string) {
+       for len(pattern) > 0 && pattern[0] == '*' {
+               pattern = pattern[1:]
+               star = true
+       }
+       inrange := false
+       var i int
+Scan:
+       for i = 0; i < len(pattern); i++ {
+               switch pattern[i] {
+               case '\\':
+                       // error check handled in matchChunk: bad pattern.
+                       if i+1 < len(pattern) {
+                               i++
+                       }
+               case '[':
+                       inrange = true
+               case ']':
+                       inrange = false
+               case '*':
+                       if !inrange {
+                               break Scan
+                       }
+               }
+       }
+       return star, pattern[0:i], pattern[i:]
+}
+
+// matchChunk checks whether chunk matches the beginning of s.
+// If so, it returns the remainder of s (after the match).
+// Chunk is all single-character operators: literals, char classes, and ?.
+func matchChunk(chunk, s string) (rest string, ok bool, err os.Error) {
+       for len(chunk) > 0 {
+               if len(s) == 0 {
+                       return
+               }
+               switch chunk[0] {
+               case '[':
+                       // character class
+                       r, n := utf8.DecodeRuneInString(s)
+                       s = s[n:]
+                       chunk = chunk[1:]
+                       // possibly negated
+                       notNegated := true
+                       if len(chunk) > 0 && chunk[0] == '^' {
+                               notNegated = false
+                               chunk = chunk[1:]
+                       }
+                       // parse all ranges
+                       match := false
+                       nrange := 0
+                       for {
+                               if len(chunk) > 0 && chunk[0] == ']' && nrange > 0 {
+                                       chunk = chunk[1:]
+                                       break
+                               }
+                               var lo, hi int
+                               if lo, chunk, err = getEsc(chunk); err != nil {
+                                       return
+                               }
+                               hi = lo
+                               if chunk[0] == '-' {
+                                       if hi, chunk, err = getEsc(chunk[1:]); err != nil {
+                                               return
+                                       }
+                               }
+                               if lo <= r && r <= hi {
+                                       match = true
+                               }
+                               nrange++
+                       }
+                       if match != notNegated {
+                               return
+                       }
+
+               case '?':
+                       if s[0] == Separator {
+                               return
+                       }
+                       _, n := utf8.DecodeRuneInString(s)
+                       s = s[n:]
+                       chunk = chunk[1:]
+
+               case '\\':
+                       chunk = chunk[1:]
+                       if len(chunk) == 0 {
+                               err = ErrBadPattern
+                               return
+                       }
+                       fallthrough
+
+               default:
+                       if chunk[0] != s[0] {
+                               return
+                       }
+                       s = s[1:]
+                       chunk = chunk[1:]
+               }
+       }
+       return s, true, nil
+}
+
+// getEsc gets a possibly-escaped character from chunk, for a character class.
+func getEsc(chunk string) (r int, nchunk string, err os.Error) {
+       if len(chunk) == 0 || chunk[0] == '-' || chunk[0] == ']' {
+               err = ErrBadPattern
+               return
+       }
+       if chunk[0] == '\\' {
+               chunk = chunk[1:]
+               if len(chunk) == 0 {
+                       err = ErrBadPattern
+                       return
+               }
+       }
+       r, n := utf8.DecodeRuneInString(chunk)
+       if r == utf8.RuneError && n == 1 {
+               err = ErrBadPattern
+       }
+       nchunk = chunk[n:]
+       if len(nchunk) == 0 {
+               err = ErrBadPattern
+       }
+       return
+}
+
+// Glob returns the names of all files matching pattern or nil
+// if there is no matching file. The syntax of patterns is the same
+// as in Match. The pattern may describe hierarchical names such as
+// /usr/*/bin/ed (assuming the Separator is '/').
+//
+func Glob(pattern string) (matches []string) {
+       if !hasMeta(pattern) {
+               if _, err := os.Stat(pattern); err == nil {
+                       return []string{pattern}
+               }
+               return nil
+       }
+
+       dir, file := Split(pattern)
+       switch dir {
+       case "":
+               dir = "."
+       case string(Separator):
+               // nothing
+       default:
+               dir = dir[0 : len(dir)-1] // chop off trailing separator
+       }
+
+       if hasMeta(dir) {
+               for _, d := range Glob(dir) {
+                       matches = glob(d, file, matches)
+               }
+       } else {
+               return glob(dir, file, nil)
+       }
+       return matches
+}
+
+// glob searches for files matching pattern in the directory dir
+// and appends them to matches.
+func glob(dir, pattern string, matches []string) []string {
+       fi, err := os.Stat(dir)
+       if err != nil {
+               return nil
+       }
+       if !fi.IsDirectory() {
+               return matches
+       }
+       d, err := os.Open(dir, os.O_RDONLY, 0666)
+       if err != nil {
+               return nil
+       }
+       defer d.Close()
+
+       names, err := d.Readdirnames(-1)
+       if err != nil {
+               return nil
+       }
+       sort.SortStrings(names)
+
+       for _, n := range names {
+               matched, err := Match(pattern, n)
+               if err != nil {
+                       return matches
+               }
+               if matched {
+                       matches = append(matches, Join(dir, n))
+               }
+       }
+       return matches
+}
+
+// hasMeta returns true if path contains any of the magic characters
+// recognized by Match.
+func hasMeta(path string) bool {
+       // TODO(niemeyer): Should other magic characters be added here?
+       return strings.IndexAny(path, "*?[") >= 0
+}
diff --git a/libgo/go/path/filepath/match_test.go b/libgo/go/path/filepath/match_test.go
new file mode 100644 (file)
index 0000000..a1e6316
--- /dev/null
@@ -0,0 +1,117 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package filepath_test
+
+import (
+       "os"
+       "path/filepath"
+       "testing"
+       "runtime"
+)
+
+type MatchTest struct {
+       pattern, s string
+       match      bool
+       err        os.Error
+}
+
+var matchTests = []MatchTest{
+       {"abc", "abc", true, nil},
+       {"*", "abc", true, nil},
+       {"*c", "abc", true, nil},
+       {"a*", "a", true, nil},
+       {"a*", "abc", true, nil},
+       {"a*", "ab/c", false, nil},
+       {"a*/b", "abc/b", true, nil},
+       {"a*/b", "a/c/b", false, nil},
+       {"a*b*c*d*e*/f", "axbxcxdxe/f", true, nil},
+       {"a*b*c*d*e*/f", "axbxcxdxexxx/f", true, nil},
+       {"a*b*c*d*e*/f", "axbxcxdxe/xxx/f", false, nil},
+       {"a*b*c*d*e*/f", "axbxcxdxexxx/fff", false, nil},
+       {"a*b?c*x", "abxbbxdbxebxczzx", true, nil},
+       {"a*b?c*x", "abxbbxdbxebxczzy", false, nil},
+       {"ab[c]", "abc", true, nil},
+       {"ab[b-d]", "abc", true, nil},
+       {"ab[e-g]", "abc", false, nil},
+       {"ab[^c]", "abc", false, nil},
+       {"ab[^b-d]", "abc", false, nil},
+       {"ab[^e-g]", "abc", true, nil},
+       {"a\\*b", "a*b", true, nil},
+       {"a\\*b", "ab", false, nil},
+       {"a?b", "a☺b", true, nil},
+       {"a[^a]b", "a☺b", true, nil},
+       {"a???b", "a☺b", false, nil},
+       {"a[^a][^a][^a]b", "a☺b", false, nil},
+       {"[a-ζ]*", "α", true, nil},
+       {"*[a-ζ]", "A", false, nil},
+       {"a?b", "a/b", false, nil},
+       {"a*b", "a/b", false, nil},
+       {"[\\]a]", "]", true, nil},
+       {"[\\-]", "-", true, nil},
+       {"[x\\-]", "x", true, nil},
+       {"[x\\-]", "-", true, nil},
+       {"[x\\-]", "z", false, nil},
+       {"[\\-x]", "x", true, nil},
+       {"[\\-x]", "-", true, nil},
+       {"[\\-x]", "a", false, nil},
+       {"[]a]", "]", false, filepath.ErrBadPattern},
+       {"[-]", "-", false, filepath.ErrBadPattern},
+       {"[x-]", "x", false, filepath.ErrBadPattern},
+       {"[x-]", "-", false, filepath.ErrBadPattern},
+       {"[x-]", "z", false, filepath.ErrBadPattern},
+       {"[-x]", "x", false, filepath.ErrBadPattern},
+       {"[-x]", "-", false, filepath.ErrBadPattern},
+       {"[-x]", "a", false, filepath.ErrBadPattern},
+       {"\\", "a", false, filepath.ErrBadPattern},
+       {"[a-b-c]", "a", false, filepath.ErrBadPattern},
+       {"*x", "xxx", true, nil},
+}
+
+func TestMatch(t *testing.T) {
+       if runtime.GOOS == "windows" {
+               // XXX: Don't pass for windows.
+               return
+       }
+       for _, tt := range matchTests {
+               ok, err := filepath.Match(tt.pattern, tt.s)
+               if ok != tt.match || err != tt.err {
+                       t.Errorf("Match(%#q, %#q) = %v, %v want %v, nil", tt.pattern, tt.s, ok, err, tt.match)
+               }
+       }
+}
+
+// contains returns true if vector contains the string s.
+func contains(vector []string, s string) bool {
+       s = filepath.ToSlash(s)
+       for _, elem := range vector {
+               if elem == s {
+                       return true
+               }
+       }
+       return false
+}
+
+var globTests = []struct {
+       pattern, result string
+}{
+       {"match.go", "match.go"},
+       {"mat?h.go", "match.go"},
+       {"*", "match.go"},
+       // Does not work in gccgo test environment.
+       // {"../*/match.go", "../filepath/match.go"},
+}
+
+func TestGlob(t *testing.T) {
+       if runtime.GOOS == "windows" {
+               // XXX: Don't pass for windows.
+               return
+       }
+       for _, tt := range globTests {
+               matches := filepath.Glob(tt.pattern)
+               if !contains(matches, tt.result) {
+                       t.Errorf("Glob(%#q) = %#v want %v", tt.pattern, matches, tt.result)
+               }
+       }
+}
diff --git a/libgo/go/path/filepath/path.go b/libgo/go/path/filepath/path.go
new file mode 100644 (file)
index 0000000..6cd6cf2
--- /dev/null
@@ -0,0 +1,335 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The filepath package implements utility routines for manipulating
+// filename paths in a way compatible with the target operating
+// system-defined file paths.
+package filepath
+
+import (
+       "bytes"
+       "os"
+       "sort"
+       "strings"
+)
+
+const (
+       SeparatorString     = string(Separator)
+       ListSeparatorString = string(ListSeparator)
+)
+
+// Clean returns the shortest path name equivalent to path
+// by purely lexical processing.  It applies the following rules
+// iteratively until no further processing can be done:
+//
+//     1. Replace multiple Separator elements with a single one.
+//     2. Eliminate each . path name element (the current directory).
+//     3. Eliminate each inner .. path name element (the parent directory)
+//        along with the non-.. element that precedes it.
+//     4. Eliminate .. elements that begin a rooted path:
+//        that is, replace "/.." by "/" at the beginning of a path,
+//         assuming Separator is '/'.
+//
+// If the result of this process is an empty string, Clean
+// returns the string ".".
+//
+// See also Rob Pike, ``Lexical File Names in Plan 9 or
+// Getting Dot-Dot right,''
+// http://plan9.bell-labs.com/sys/doc/lexnames.html
+func Clean(path string) string {
+       if path == "" {
+               return "."
+       }
+
+       rooted := IsAbs(path)
+
+       // Invariants:
+       //      reading from path; r is index of next byte to process.
+       //      writing to buf; w is index of next byte to write.
+       //      dotdot is index in buf where .. must stop, either because
+       //              it is the leading slash or it is a leading ../../.. prefix.
+       prefix := volumeName(path)
+       path = path[len(prefix):]
+       n := len(path)
+       buf := []byte(path)
+       r, w, dotdot := 0, 0, 0
+       if rooted {
+               buf[0] = Separator
+               r, w, dotdot = 1, 1, 1
+       }
+
+       for r < n {
+               switch {
+               case isSeparator(path[r]):
+                       // empty path element
+                       r++
+               case path[r] == '.' && (r+1 == n || isSeparator(path[r+1])):
+                       // . element
+                       r++
+               case path[r] == '.' && path[r+1] == '.' && (r+2 == n || isSeparator(path[r+2])):
+                       // .. element: remove to last separator
+                       r += 2
+                       switch {
+                       case w > dotdot:
+                               // can backtrack
+                               w--
+                               for w > dotdot && !isSeparator(buf[w]) {
+                                       w--
+                               }
+                       case !rooted:
+                               // cannot backtrack, but not rooted, so append .. element.
+                               if w > 0 {
+                                       buf[w] = Separator
+                                       w++
+                               }
+                               buf[w] = '.'
+                               w++
+                               buf[w] = '.'
+                               w++
+                               dotdot = w
+                       }
+               default:
+                       // real path element.
+                       // add slash if needed
+                       if rooted && w != 1 || !rooted && w != 0 {
+                               buf[w] = Separator
+                               w++
+                       }
+                       // copy element
+                       for ; r < n && !isSeparator(path[r]); r++ {
+                               buf[w] = path[r]
+                               w++
+                       }
+               }
+       }
+
+       // Turn empty string into "."
+       if w == 0 {
+               buf[w] = '.'
+               w++
+       }
+
+       return prefix + string(buf[0:w])
+}
+
+// ToSlash returns the result of replacing each separator character
+// in path with a slash ('/') character.
+func ToSlash(path string) string {
+       if Separator == '/' {
+               return path
+       }
+       return strings.Replace(path, SeparatorString, "/", -1)
+}
+
+// FromSlash returns the result of replacing each slash ('/') character
+// in path with a separator character.
+func FromSlash(path string) string {
+       if Separator == '/' {
+               return path
+       }
+       return strings.Replace(path, "/", SeparatorString, -1)
+}
+
+// SplitList splits a list of paths joined by the OS-specific ListSeparator.
+func SplitList(path string) []string {
+       if path == "" {
+               return []string{}
+       }
+       return strings.Split(path, ListSeparatorString, -1)
+}
+
+// Split splits path immediately following the final Separator,
+// partitioning it into a directory and a file name components.
+// If there are no separators in path, Split returns an empty base
+// and file set to path.
+func Split(path string) (dir, file string) {
+       i := len(path) - 1
+       for i >= 0 && !isSeparator(path[i]) {
+               i--
+       }
+       return path[:i+1], path[i+1:]
+}
+
+// Join joins any number of path elements into a single path, adding
+// a Separator if necessary.  All empty strings are ignored.
+func Join(elem ...string) string {
+       for i, e := range elem {
+               if e != "" {
+                       return Clean(strings.Join(elem[i:], SeparatorString))
+               }
+       }
+       return ""
+}
+
+// Ext returns the file name extension used by path.
+// The extension is the suffix beginning at the final dot
+// in the final element of path; it is empty if there is
+// no dot.
+func Ext(path string) string {
+       for i := len(path) - 1; i >= 0 && !isSeparator(path[i]); i-- {
+               if path[i] == '.' {
+                       return path[i:]
+               }
+       }
+       return ""
+}
+
+// EvalSymlinks returns the path name after the evaluation of any symbolic
+// links.
+// If path is relative it will be evaluated relative to the current directory.
+func EvalSymlinks(path string) (string, os.Error) {
+       const maxIter = 255
+       originalPath := path
+       // consume path by taking each frontmost path element,
+       // expanding it if it's a symlink, and appending it to b
+       var b bytes.Buffer
+       for n := 0; path != ""; n++ {
+               if n > maxIter {
+                       return "", os.NewError("EvalSymlinks: too many links in " + originalPath)
+               }
+
+               // find next path component, p
+               i := strings.IndexRune(path, Separator)
+               var p string
+               if i == -1 {
+                       p, path = path, ""
+               } else {
+                       p, path = path[:i], path[i+1:]
+               }
+
+               if p == "" {
+                       if b.Len() == 0 {
+                               // must be absolute path
+                               b.WriteRune(Separator)
+                       }
+                       continue
+               }
+
+               fi, err := os.Lstat(b.String() + p)
+               if err != nil {
+                       return "", err
+               }
+               if !fi.IsSymlink() {
+                       b.WriteString(p)
+                       if path != "" {
+                               b.WriteRune(Separator)
+                       }
+                       continue
+               }
+
+               // it's a symlink, put it at the front of path
+               dest, err := os.Readlink(b.String() + p)
+               if err != nil {
+                       return "", err
+               }
+               if IsAbs(dest) {
+                       b.Reset()
+               }
+               path = dest + SeparatorString + path
+       }
+       return Clean(b.String()), nil
+}
+
+// Visitor methods are invoked for corresponding file tree entries
+// visited by Walk. The parameter path is the full path of f relative
+// to root.
+type Visitor interface {
+       VisitDir(path string, f *os.FileInfo) bool
+       VisitFile(path string, f *os.FileInfo)
+}
+
+func walk(path string, f *os.FileInfo, v Visitor, errors chan<- os.Error) {
+       if !f.IsDirectory() {
+               v.VisitFile(path, f)
+               return
+       }
+
+       if !v.VisitDir(path, f) {
+               return // skip directory entries
+       }
+
+       list, err := readDir(path)
+       if err != nil {
+               if errors != nil {
+                       errors <- err
+               }
+       }
+
+       for _, e := range list {
+               walk(Join(path, e.Name), e, v, errors)
+       }
+}
+
+// readDir reads the directory named by dirname and returns
+// a list of sorted directory entries.
+// Copied from io/ioutil to avoid the circular import.
+func readDir(dirname string) ([]*os.FileInfo, os.Error) {
+       f, err := os.Open(dirname, os.O_RDONLY, 0)
+       if err != nil {
+               return nil, err
+       }
+       list, err := f.Readdir(-1)
+       f.Close()
+       if err != nil {
+               return nil, err
+       }
+       fi := make(fileInfoList, len(list))
+       for i := range list {
+               fi[i] = &list[i]
+       }
+       sort.Sort(fi)
+       return fi, nil
+}
+
+// A dirList implements sort.Interface.
+type fileInfoList []*os.FileInfo
+
+func (f fileInfoList) Len() int           { return len(f) }
+func (f fileInfoList) Less(i, j int) bool { return f[i].Name < f[j].Name }
+func (f fileInfoList) Swap(i, j int)      { f[i], f[j] = f[j], f[i] }
+
+// Walk walks the file tree rooted at root, calling v.VisitDir or
+// v.VisitFile for each directory or file in the tree, including root.
+// If v.VisitDir returns false, Walk skips the directory's entries;
+// otherwise it invokes itself for each directory entry in sorted order.
+// An error reading a directory does not abort the Walk.
+// If errors != nil, Walk sends each directory read error
+// to the channel.  Otherwise Walk discards the error.
+func Walk(root string, v Visitor, errors chan<- os.Error) {
+       f, err := os.Lstat(root)
+       if err != nil {
+               if errors != nil {
+                       errors <- err
+               }
+               return // can't progress
+       }
+       walk(root, f, v, errors)
+}
+
+// Base returns the last element of path.
+// Trailing path separators are removed before extracting the last element.
+// If the path is empty, Base returns ".".
+// If the path consists entirely of separators, Base returns a single separator.
+func Base(path string) string {
+       if path == "" {
+               return "."
+       }
+       // Strip trailing slashes.
+       for len(path) > 0 && isSeparator(path[len(path)-1]) {
+               path = path[0 : len(path)-1]
+       }
+       // Find the last element
+       i := len(path) - 1
+       for i >= 0 && !isSeparator(path[i]) {
+               i--
+       }
+       if i >= 0 {
+               path = path[i+1:]
+       }
+       // If empty now, it had only slashes.
+       if path == "" {
+               return SeparatorString
+       }
+       return path
+}
diff --git a/libgo/go/path/filepath/path_test.go b/libgo/go/path/filepath/path_test.go
new file mode 100644 (file)
index 0000000..0249af4
--- /dev/null
@@ -0,0 +1,486 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package filepath_test
+
+import (
+       "os"
+       "path/filepath"
+       "reflect"
+       "runtime"
+       "testing"
+)
+
+type PathTest struct {
+       path, result string
+}
+
+var cleantests = []PathTest{
+       // Already clean
+       {"", "."},
+       {"abc", "abc"},
+       {"abc/def", "abc/def"},
+       {"a/b/c", "a/b/c"},
+       {".", "."},
+       {"..", ".."},
+       {"../..", "../.."},
+       {"../../abc", "../../abc"},
+       {"/abc", "/abc"},
+       {"/", "/"},
+
+       // Remove trailing slash
+       {"abc/", "abc"},
+       {"abc/def/", "abc/def"},
+       {"a/b/c/", "a/b/c"},
+       {"./", "."},
+       {"../", ".."},
+       {"../../", "../.."},
+       {"/abc/", "/abc"},
+
+       // Remove doubled slash
+       {"abc//def//ghi", "abc/def/ghi"},
+       {"//abc", "/abc"},
+       {"///abc", "/abc"},
+       {"//abc//", "/abc"},
+       {"abc//", "abc"},
+
+       // Remove . elements
+       {"abc/./def", "abc/def"},
+       {"/./abc/def", "/abc/def"},
+       {"abc/.", "abc"},
+
+       // Remove .. elements
+       {"abc/def/ghi/../jkl", "abc/def/jkl"},
+       {"abc/def/../ghi/../jkl", "abc/jkl"},
+       {"abc/def/..", "abc"},
+       {"abc/def/../..", "."},
+       {"/abc/def/../..", "/"},
+       {"abc/def/../../..", ".."},
+       {"/abc/def/../../..", "/"},
+       {"abc/def/../../../ghi/jkl/../../../mno", "../../mno"},
+
+       // Combinations
+       {"abc/./../def", "def"},
+       {"abc//./../def", "def"},
+       {"abc/../../././../def", "../../def"},
+}
+
+func TestClean(t *testing.T) {
+       for _, test := range cleantests {
+               if s := filepath.ToSlash(filepath.Clean(test.path)); s != test.result {
+                       t.Errorf("Clean(%q) = %q, want %q", test.path, s, test.result)
+               }
+       }
+}
+
+const sep = filepath.Separator
+
+var slashtests = []PathTest{
+       {"", ""},
+       {"/", string(sep)},
+       {"/a/b", string([]byte{sep, 'a', sep, 'b'})},
+       {"a//b", string([]byte{'a', sep, sep, 'b'})},
+}
+
+func TestFromAndToSlash(t *testing.T) {
+       for _, test := range slashtests {
+               if s := filepath.FromSlash(test.path); s != test.result {
+                       t.Errorf("FromSlash(%q) = %q, want %q", test.path, s, test.result)
+               }
+               if s := filepath.ToSlash(test.result); s != test.path {
+                       t.Errorf("ToSlash(%q) = %q, want %q", test.result, s, test.path)
+               }
+       }
+}
+
+type SplitListTest struct {
+       list   string
+       result []string
+}
+
+const lsep = filepath.ListSeparator
+
+var splitlisttests = []SplitListTest{
+       {"", []string{}},
+       {string([]byte{'a', lsep, 'b'}), []string{"a", "b"}},
+       {string([]byte{lsep, 'a', lsep, 'b'}), []string{"", "a", "b"}},
+}
+
+func TestSplitList(t *testing.T) {
+       for _, test := range splitlisttests {
+               if l := filepath.SplitList(test.list); !reflect.DeepEqual(l, test.result) {
+                       t.Errorf("SplitList(%q) = %s, want %s", test.list, l, test.result)
+               }
+       }
+}
+
+type SplitTest struct {
+       path, dir, file string
+}
+
+var unixsplittests = []SplitTest{
+       {"a/b", "a/", "b"},
+       {"a/b/", "a/b/", ""},
+       {"a/", "a/", ""},
+       {"a", "", "a"},
+       {"/", "/", ""},
+}
+
+func TestSplit(t *testing.T) {
+       var splittests []SplitTest
+       splittests = unixsplittests
+       for _, test := range splittests {
+               if d, f := filepath.Split(test.path); d != test.dir || f != test.file {
+                       t.Errorf("Split(%q) = %q, %q, want %q, %q", test.path, d, f, test.dir, test.file)
+               }
+       }
+}
+
+type JoinTest struct {
+       elem []string
+       path string
+}
+
+var jointests = []JoinTest{
+       // zero parameters
+       {[]string{}, ""},
+
+       // one parameter
+       {[]string{""}, ""},
+       {[]string{"a"}, "a"},
+
+       // two parameters
+       {[]string{"a", "b"}, "a/b"},
+       {[]string{"a", ""}, "a"},
+       {[]string{"", "b"}, "b"},
+       {[]string{"/", "a"}, "/a"},
+       {[]string{"/", ""}, "/"},
+       {[]string{"a/", "b"}, "a/b"},
+       {[]string{"a/", ""}, "a"},
+       {[]string{"", ""}, ""},
+}
+
+var winjointests = []JoinTest{
+       {[]string{`directory`, `file`}, `directory\file`},
+       {[]string{`C:\Windows\`, `System32`}, `C:\Windows\System32`},
+       {[]string{`C:\Windows\`, ``}, `C:\Windows`},
+       {[]string{`C:\`, `Windows`}, `C:\Windows`},
+       {[]string{`C:`, `Windows`}, `C:\Windows`},
+}
+
+// join takes a []string and passes it to Join.
+func join(elem []string, args ...string) string {
+       args = elem
+       return filepath.Join(args...)
+}
+
+func TestJoin(t *testing.T) {
+       if runtime.GOOS == "windows" {
+               jointests = append(jointests, winjointests...)
+       }
+       for _, test := range jointests {
+               if p := join(test.elem); p != filepath.FromSlash(test.path) {
+                       t.Errorf("join(%q) = %q, want %q", test.elem, p, test.path)
+               }
+       }
+}
+
+type ExtTest struct {
+       path, ext string
+}
+
+var exttests = []ExtTest{
+       {"path.go", ".go"},
+       {"path.pb.go", ".go"},
+       {"a.dir/b", ""},
+       {"a.dir/b.go", ".go"},
+       {"a.dir/", ""},
+}
+
+func TestExt(t *testing.T) {
+       for _, test := range exttests {
+               if x := filepath.Ext(test.path); x != test.ext {
+                       t.Errorf("Ext(%q) = %q, want %q", test.path, x, test.ext)
+               }
+       }
+}
+
+type Node struct {
+       name    string
+       entries []*Node // nil if the entry is a file
+       mark    int
+}
+
+var tree = &Node{
+       "testdata",
+       []*Node{
+               &Node{"a", nil, 0},
+               &Node{"b", []*Node{}, 0},
+               &Node{"c", nil, 0},
+               &Node{
+                       "d",
+                       []*Node{
+                               &Node{"x", nil, 0},
+                               &Node{"y", []*Node{}, 0},
+                               &Node{
+                                       "z",
+                                       []*Node{
+                                               &Node{"u", nil, 0},
+                                               &Node{"v", nil, 0},
+                                       },
+                                       0,
+                               },
+                       },
+                       0,
+               },
+       },
+       0,
+}
+
+func walkTree(n *Node, path string, f func(path string, n *Node)) {
+       f(path, n)
+       for _, e := range n.entries {
+               walkTree(e, filepath.Join(path, e.name), f)
+       }
+}
+
+func makeTree(t *testing.T) {
+       walkTree(tree, tree.name, func(path string, n *Node) {
+               if n.entries == nil {
+                       fd, err := os.Open(path, os.O_CREAT, 0660)
+                       if err != nil {
+                               t.Errorf("makeTree: %v", err)
+                       }
+                       fd.Close()
+               } else {
+                       os.Mkdir(path, 0770)
+               }
+       })
+}
+
+func markTree(n *Node) { walkTree(n, "", func(path string, n *Node) { n.mark++ }) }
+
+func checkMarks(t *testing.T) {
+       walkTree(tree, tree.name, func(path string, n *Node) {
+               if n.mark != 1 {
+                       t.Errorf("node %s mark = %d; expected 1", path, n.mark)
+               }
+               n.mark = 0
+       })
+}
+
+// Assumes that each node name is unique. Good enough for a test.
+func mark(name string) {
+       name = filepath.ToSlash(name)
+       walkTree(tree, tree.name, func(path string, n *Node) {
+               if n.name == name {
+                       n.mark++
+               }
+       })
+}
+
+type TestVisitor struct{}
+
+func (v *TestVisitor) VisitDir(path string, f *os.FileInfo) bool {
+       mark(f.Name)
+       return true
+}
+
+func (v *TestVisitor) VisitFile(path string, f *os.FileInfo) {
+       mark(f.Name)
+}
+
+func TestWalk(t *testing.T) {
+       // TODO(brainman): enable test once Windows version is implemented.
+       if runtime.GOOS == "windows" {
+               return
+       }
+       makeTree(t)
+
+       // 1) ignore error handling, expect none
+       v := &TestVisitor{}
+       filepath.Walk(tree.name, v, nil)
+       checkMarks(t)
+
+       // 2) handle errors, expect none
+       errors := make(chan os.Error, 64)
+       filepath.Walk(tree.name, v, errors)
+       select {
+       case err := <-errors:
+               t.Errorf("no error expected, found: %s", err)
+       default:
+               // ok
+       }
+       checkMarks(t)
+
+       if os.Getuid() > 0 {
+               // introduce 2 errors: chmod top-level directories to 0
+               os.Chmod(filepath.Join(tree.name, tree.entries[1].name), 0)
+               os.Chmod(filepath.Join(tree.name, tree.entries[3].name), 0)
+               // mark respective subtrees manually
+               markTree(tree.entries[1])
+               markTree(tree.entries[3])
+               // correct double-marking of directory itself
+               tree.entries[1].mark--
+               tree.entries[3].mark--
+
+               // 3) handle errors, expect two
+               errors = make(chan os.Error, 64)
+               os.Chmod(filepath.Join(tree.name, tree.entries[1].name), 0)
+               filepath.Walk(tree.name, v, errors)
+       Loop:
+               for i := 1; i <= 2; i++ {
+                       select {
+                       case <-errors:
+                               // ok
+                       default:
+                               t.Errorf("%d. error expected, none found", i)
+                               break Loop
+                       }
+               }
+               select {
+               case err := <-errors:
+                       t.Errorf("only two errors expected, found 3rd: %v", err)
+               default:
+                       // ok
+               }
+               // the inaccessible subtrees were marked manually
+               checkMarks(t)
+       }
+
+       // cleanup
+       os.Chmod(filepath.Join(tree.name, tree.entries[1].name), 0770)
+       os.Chmod(filepath.Join(tree.name, tree.entries[3].name), 0770)
+       if err := os.RemoveAll(tree.name); err != nil {
+               t.Errorf("removeTree: %v", err)
+       }
+}
+
+var basetests = []PathTest{
+       {"", "."},
+       {".", "."},
+       {"/.", "."},
+       {"/", "/"},
+       {"////", "/"},
+       {"x/", "x"},
+       {"abc", "abc"},
+       {"abc/def", "def"},
+       {"a/b/.x", ".x"},
+       {"a/b/c.", "c."},
+       {"a/b/c.x", "c.x"},
+}
+
+func TestBase(t *testing.T) {
+       for _, test := range basetests {
+               if s := filepath.ToSlash(filepath.Base(test.path)); s != test.result {
+                       t.Errorf("Base(%q) = %q, want %q", test.path, s, test.result)
+               }
+       }
+}
+
+type IsAbsTest struct {
+       path  string
+       isAbs bool
+}
+
+var isabstests = []IsAbsTest{
+       {"", false},
+       {"/", true},
+       {"/usr/bin/gcc", true},
+       {"..", false},
+       {"/a/../bb", true},
+       {".", false},
+       {"./", false},
+       {"lala", false},
+}
+
+var winisabstests = []IsAbsTest{
+       {`C:\`, true},
+       {`c\`, false},
+       {`c::`, false},
+       {`/`, true},
+       {`\`, true},
+       {`\Windows`, true},
+}
+
+func TestIsAbs(t *testing.T) {
+       if runtime.GOOS == "windows" {
+               isabstests = append(isabstests, winisabstests...)
+       }
+       for _, test := range isabstests {
+               if r := filepath.IsAbs(test.path); r != test.isAbs {
+                       t.Errorf("IsAbs(%q) = %v, want %v", test.path, r, test.isAbs)
+               }
+       }
+}
+
+type EvalSymlinksTest struct {
+       path, dest string
+}
+
+var EvalSymlinksTestDirs = []EvalSymlinksTest{
+       {"test", ""},
+       {"test/dir", ""},
+       {"test/dir/link3", "../../"},
+       {"test/link1", "../test"},
+       {"test/link2", "dir"},
+}
+
+var EvalSymlinksTests = []EvalSymlinksTest{
+       {"test", "test"},
+       {"test/dir", "test/dir"},
+       {"test/dir/../..", "."},
+       {"test/link1", "test"},
+       {"test/link2", "test/dir"},
+       {"test/link1/dir", "test/dir"},
+       {"test/link2/..", "test"},
+       {"test/dir/link3", "."},
+       {"test/link2/link3/test", "test"},
+}
+
+func TestEvalSymlinks(t *testing.T) {
+       // Symlinks are not supported under windows.
+       if runtime.GOOS == "windows" {
+               return
+       }
+       defer os.RemoveAll("test")
+       for _, d := range EvalSymlinksTestDirs {
+               var err os.Error
+               if d.dest == "" {
+                       err = os.Mkdir(d.path, 0755)
+               } else {
+                       err = os.Symlink(d.dest, d.path)
+               }
+               if err != nil {
+                       t.Fatal(err)
+               }
+       }
+       // relative
+       for _, d := range EvalSymlinksTests {
+               if p, err := filepath.EvalSymlinks(d.path); err != nil {
+                       t.Errorf("EvalSymlinks(%v) error: %v", d.path, err)
+               } else if p != d.dest {
+                       t.Errorf("EvalSymlinks(%v)=%v, want %v", d.path, p, d.dest)
+               }
+       }
+       // absolute
+/* These tests do not work in the gccgo test environment.
+       goroot, err := filepath.EvalSymlinks(os.Getenv("GOROOT"))
+       if err != nil {
+               t.Fatalf("EvalSymlinks(%q) error: %v", os.Getenv("GOROOT"), err)
+       }
+       testroot := filepath.Join(goroot, "src", "pkg", "path", "filepath")
+       for _, d := range EvalSymlinksTests {
+               a := EvalSymlinksTest{
+                       filepath.Join(testroot, d.path),
+                       filepath.Join(testroot, d.dest),
+               }
+               if p, err := filepath.EvalSymlinks(a.path); err != nil {
+                       t.Errorf("EvalSymlinks(%v) error: %v", a.path, err)
+               } else if p != a.dest {
+                       t.Errorf("EvalSymlinks(%v)=%v, want %v", a.path, p, a.dest)
+               }
+       }
+*/
+}
diff --git a/libgo/go/path/filepath/path_unix.go b/libgo/go/path/filepath/path_unix.go
new file mode 100644 (file)
index 0000000..1bb21ec
--- /dev/null
@@ -0,0 +1,28 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package filepath
+
+import "strings"
+
+const (
+       Separator     = '/' // OS-specific path separator
+       ListSeparator = ':' // OS-specific path list separator
+)
+
+// isSeparator returns true if c is a directory separator character.
+func isSeparator(c uint8) bool {
+       return Separator == c
+}
+
+// IsAbs returns true if the path is absolute.
+func IsAbs(path string) bool {
+       return strings.HasPrefix(path, "/")
+}
+
+// volumeName returns the leading volume name on Windows.
+// It returns "" on Unix.
+func volumeName(path string) string {
+       return ""
+}
diff --git a/libgo/go/path/filepath/path_windows.go b/libgo/go/path/filepath/path_windows.go
new file mode 100644 (file)
index 0000000..dbd1c1e
--- /dev/null
@@ -0,0 +1,37 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package filepath
+
+const (
+       Separator     = '\\' // OS-specific path separator
+       ListSeparator = ':'  // OS-specific path list separator
+)
+
+// isSeparator returns true if c is a directory separator character.
+func isSeparator(c uint8) bool {
+       // NOTE: Windows accept / as path separator.
+       return c == '\\' || c == '/'
+}
+
+// IsAbs returns true if the path is absolute.
+func IsAbs(path string) bool {
+       return path != "" && (volumeName(path) != "" || isSeparator(path[0]))
+}
+
+// volumeName return leading volume name.  
+// If given "C:\foo\bar", return "C:" on windows.
+func volumeName(path string) string {
+       if path == "" {
+               return ""
+       }
+       // with drive letter
+       c := path[0]
+       if len(path) > 2 && path[1] == ':' && isSeparator(path[2]) &&
+               ('0' <= c && c <= '9' || 'a' <= c && c <= 'z' ||
+                       'A' <= c && c <= 'Z') {
+               return path[0:2]
+       }
+       return ""
+}
index dd3422c4256280b7f66aa698851e8056dec6903e..efb8c5ce7fcf075d81ce04cb3101b10eadc2f298 100644 (file)
@@ -1,8 +1,11 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
 package path
 
 import (
        "os"
-       "sort"
        "strings"
        "utf8"
 )
@@ -10,7 +13,7 @@ import (
 var ErrBadPattern = os.NewError("syntax error in pattern")
 
 // Match returns true if name matches the shell file name pattern.
-// The syntax used by pattern is:
+// The pattern syntax is:
 //
 //     pattern:
 //             { term }
@@ -75,7 +78,7 @@ Pattern:
        return len(name) == 0, nil
 }
 
-// scanChunk gets the next section of pattern, which is a non-star string
+// scanChunk gets the next segment of pattern, which is a non-star string
 // possibly preceded by a star.
 func scanChunk(pattern string) (star bool, chunk, rest string) {
        for len(pattern) > 0 && pattern[0] == '*' {
@@ -92,7 +95,6 @@ Scan:
                        if i+1 < len(pattern) {
                                i++
                        }
-                       continue
                case '[':
                        inrange = true
                case ']':
@@ -203,76 +205,3 @@ func getEsc(chunk string) (r int, nchunk string, err os.Error) {
        }
        return
 }
-
-// Glob returns the names of all files matching pattern or nil
-// if there is no matching file. The syntax of patterns is the same
-// as in Match. The pattern may describe hierarchical names such as
-// /usr/*/bin/ed.
-//
-func Glob(pattern string) (matches []string) {
-       if !hasMeta(pattern) {
-               if _, err := os.Stat(pattern); err == nil {
-                       return []string{pattern}
-               }
-               return nil
-       }
-
-       dir, file := Split(pattern)
-       switch dir {
-       case "":
-               dir = "."
-       case "/":
-               // nothing
-       default:
-               dir = dir[0 : len(dir)-1] // chop off trailing '/'
-       }
-
-       if hasMeta(dir) {
-               for _, d := range Glob(dir) {
-                       matches = glob(d, file, matches)
-               }
-       } else {
-               return glob(dir, file, nil)
-       }
-       return matches
-}
-
-// glob searches for files matching pattern in the directory dir
-// and appends them to matches.
-func glob(dir, pattern string, matches []string) []string {
-       fi, err := os.Stat(dir)
-       if err != nil {
-               return nil
-       }
-       if !fi.IsDirectory() {
-               return matches
-       }
-       d, err := os.Open(dir, os.O_RDONLY, 0666)
-       if err != nil {
-               return nil
-       }
-       defer d.Close()
-
-       names, err := d.Readdirnames(-1)
-       if err != nil {
-               return nil
-       }
-       sort.SortStrings(names)
-
-       for _, n := range names {
-               matched, err := Match(pattern, n)
-               if err != nil {
-                       return matches
-               }
-               if matched {
-                       matches = append(matches, Join(dir, n))
-               }
-       }
-       return matches
-}
-
-// hasMeta returns true if path contains any of the magic characters
-// recognized by Match.
-func hasMeta(path string) bool {
-       return strings.IndexAny(path, "*?[") != -1
-}
index 141cff7da01d1f40ce76f6454d1f33c15f783ebb..f377f1083b77a25bf947d8859ae8a5fef73dde5f 100644 (file)
@@ -75,32 +75,3 @@ func TestMatch(t *testing.T) {
                }
        }
 }
-
-// contains returns true if vector contains the string s.
-func contains(vector []string, s string) bool {
-       for _, elem := range vector {
-               if elem == s {
-                       return true
-               }
-       }
-       return false
-}
-
-var globTests = []struct {
-       pattern, result string
-}{
-       {"match.go", "match.go"},
-       {"mat?h.go", "match.go"},
-       {"*", "match.go"},
-       // Fails in the gccgo test environment.
-       // {"../*/match.go", "../path/match.go"},
-}
-
-func TestGlob(t *testing.T) {
-       for _, tt := range globTests {
-               matches := Glob(tt.pattern)
-               if !contains(matches, tt.result) {
-                       t.Errorf("Glob(%#q) = %#v want %v", tt.pattern, matches, tt.result)
-               }
-       }
-}
index 61eea88588bd98baa94d9ebfc685c787c3ee188c..658eec0938706dfd3b6bab90bb77f30a71d89cc6 100644 (file)
@@ -7,8 +7,6 @@
 package path
 
 import (
-       "io/ioutil"
-       "os"
        "strings"
 )
 
@@ -107,7 +105,7 @@ func Clean(path string) string {
 // If there is no separator in path, Split returns an empty dir and
 // file set to path.
 func Split(path string) (dir, file string) {
-       i := strings.LastIndexAny(path, PathSeps)
+       i := strings.LastIndex(path, "/")
        return path[:i+1], path[i+1:]
 }
 
@@ -135,78 +133,30 @@ func Ext(path string) string {
        return ""
 }
 
-// Visitor methods are invoked for corresponding file tree entries
-// visited by Walk. The parameter path is the full path of f relative
-// to root.
-type Visitor interface {
-       VisitDir(path string, f *os.FileInfo) bool
-       VisitFile(path string, f *os.FileInfo)
-}
-
-func walk(path string, f *os.FileInfo, v Visitor, errors chan<- os.Error) {
-       if !f.IsDirectory() {
-               v.VisitFile(path, f)
-               return
-       }
-
-       if !v.VisitDir(path, f) {
-               return // skip directory entries
-       }
-
-       list, err := ioutil.ReadDir(path)
-       if err != nil {
-               if errors != nil {
-                       errors <- err
-               }
-       }
-
-       for _, e := range list {
-               walk(Join(path, e.Name), e, v, errors)
-       }
-}
-
-// Walk walks the file tree rooted at root, calling v.VisitDir or
-// v.VisitFile for each directory or file in the tree, including root.
-// If v.VisitDir returns false, Walk skips the directory's entries;
-// otherwise it invokes itself for each directory entry in sorted order.
-// An error reading a directory does not abort the Walk.
-// If errors != nil, Walk sends each directory read error
-// to the channel.  Otherwise Walk discards the error.
-func Walk(root string, v Visitor, errors chan<- os.Error) {
-       f, err := os.Lstat(root)
-       if err != nil {
-               if errors != nil {
-                       errors <- err
-               }
-               return // can't progress
-       }
-       walk(root, f, v, errors)
-}
-
-// Base returns the last path element of the slash-separated name.
-// Trailing slashes are removed before extracting the last element.  If the name is
-// empty, "." is returned.  If it consists entirely of slashes, "/" is returned.
-func Base(name string) string {
-       if name == "" {
+// Base returns the last element of path.
+// Trailing slashes are removed before extracting the last element.
+// If the path is empty, Base returns ".".
+// If the path consists entirely of slashes, Base returns "/".
+func Base(path string) string {
+       if path == "" {
                return "."
        }
        // Strip trailing slashes.
-       for len(name) > 0 && name[len(name)-1] == '/' {
-               name = name[0 : len(name)-1]
+       for len(path) > 0 && path[len(path)-1] == '/' {
+               path = path[0 : len(path)-1]
        }
        // Find the last element
-       if i := strings.LastIndex(name, "/"); i >= 0 {
-               name = name[i+1:]
+       if i := strings.LastIndex(path, "/"); i >= 0 {
+               path = path[i+1:]
        }
        // If empty now, it had only slashes.
-       if name == "" {
+       if path == "" {
                return "/"
        }
-       return name
+       return path
 }
 
 // IsAbs returns true if the path is absolute.
 func IsAbs(path string) bool {
-       // TODO: Add Windows support
-       return strings.HasPrefix(path, "/")
+       return len(path) > 0 && path[0] == '/'
 }
index ab0b48ad6ad94a3cf470104c84340cf5119e9a6b..1fd57cc800ec7eccff173f03239463f2bfc81c0c 100644 (file)
@@ -5,8 +5,6 @@
 package path
 
 import (
-       "os"
-       "runtime"
        "testing"
 )
 
@@ -84,18 +82,7 @@ var splittests = []SplitTest{
        {"/", "/", ""},
 }
 
-var winsplittests = []SplitTest{
-       {`C:\Windows\System32`, `C:\Windows\`, `System32`},
-       {`C:\Windows\`, `C:\Windows\`, ``},
-       {`C:\Windows`, `C:\`, `Windows`},
-       {`C:Windows`, `C:`, `Windows`},
-       {`\\?\c:\`, `\\?\c:\`, ``},
-}
-
 func TestSplit(t *testing.T) {
-       if runtime.GOOS == "windows" {
-               splittests = append(splittests, winsplittests...)
-       }
        for _, test := range splittests {
                if d, f := Split(test.path); d != test.dir || f != test.file {
                        t.Errorf("Split(%q) = %q, %q, want %q, %q", test.path, d, f, test.dir, test.file)
@@ -161,152 +148,6 @@ func TestExt(t *testing.T) {
        }
 }
 
-type Node struct {
-       name    string
-       entries []*Node // nil if the entry is a file
-       mark    int
-}
-
-var tree = &Node{
-       "testdata",
-       []*Node{
-               &Node{"a", nil, 0},
-               &Node{"b", []*Node{}, 0},
-               &Node{"c", nil, 0},
-               &Node{
-                       "d",
-                       []*Node{
-                               &Node{"x", nil, 0},
-                               &Node{"y", []*Node{}, 0},
-                               &Node{
-                                       "z",
-                                       []*Node{
-                                               &Node{"u", nil, 0},
-                                               &Node{"v", nil, 0},
-                                       },
-                                       0,
-                               },
-                       },
-                       0,
-               },
-       },
-       0,
-}
-
-func walkTree(n *Node, path string, f func(path string, n *Node)) {
-       f(path, n)
-       for _, e := range n.entries {
-               walkTree(e, Join(path, e.name), f)
-       }
-}
-
-func makeTree(t *testing.T) {
-       walkTree(tree, tree.name, func(path string, n *Node) {
-               if n.entries == nil {
-                       fd, err := os.Open(path, os.O_CREAT, 0660)
-                       if err != nil {
-                               t.Errorf("makeTree: %v", err)
-                       }
-                       fd.Close()
-               } else {
-                       os.Mkdir(path, 0770)
-               }
-       })
-}
-
-func markTree(n *Node) { walkTree(n, "", func(path string, n *Node) { n.mark++ }) }
-
-func checkMarks(t *testing.T) {
-       walkTree(tree, tree.name, func(path string, n *Node) {
-               if n.mark != 1 {
-                       t.Errorf("node %s mark = %d; expected 1", path, n.mark)
-               }
-               n.mark = 0
-       })
-}
-
-// Assumes that each node name is unique. Good enough for a test.
-func mark(name string) {
-       walkTree(tree, tree.name, func(path string, n *Node) {
-               if n.name == name {
-                       n.mark++
-               }
-       })
-}
-
-type TestVisitor struct{}
-
-func (v *TestVisitor) VisitDir(path string, f *os.FileInfo) bool {
-       mark(f.Name)
-       return true
-}
-
-func (v *TestVisitor) VisitFile(path string, f *os.FileInfo) {
-       mark(f.Name)
-}
-
-func TestWalk(t *testing.T) {
-       makeTree(t)
-
-       // 1) ignore error handling, expect none
-       v := &TestVisitor{}
-       Walk(tree.name, v, nil)
-       checkMarks(t)
-
-       // 2) handle errors, expect none
-       errors := make(chan os.Error, 64)
-       Walk(tree.name, v, errors)
-       select {
-       case err := <-errors:
-               t.Errorf("no error expected, found: %s", err)
-       default:
-               // ok
-       }
-       checkMarks(t)
-
-       if os.Getuid() != 0 {
-               // introduce 2 errors: chmod top-level directories to 0
-               os.Chmod(Join(tree.name, tree.entries[1].name), 0)
-               os.Chmod(Join(tree.name, tree.entries[3].name), 0)
-               // mark respective subtrees manually
-               markTree(tree.entries[1])
-               markTree(tree.entries[3])
-               // correct double-marking of directory itself
-               tree.entries[1].mark--
-               tree.entries[3].mark--
-
-               // 3) handle errors, expect two
-               errors = make(chan os.Error, 64)
-               os.Chmod(Join(tree.name, tree.entries[1].name), 0)
-               Walk(tree.name, v, errors)
-       Loop:
-               for i := 1; i <= 2; i++ {
-                       select {
-                       case <-errors:
-                               // ok
-                       default:
-                               t.Errorf("%d. error expected, none found", i)
-                               break Loop
-                       }
-               }
-               select {
-               case err := <-errors:
-                       t.Errorf("only two errors expected, found 3rd: %v", err)
-               default:
-                       // ok
-               }
-               // the inaccessible subtrees were marked manually
-               checkMarks(t)
-       }
-
-       // cleanup
-       os.Chmod(Join(tree.name, tree.entries[1].name), 0770)
-       os.Chmod(Join(tree.name, tree.entries[3].name), 0770)
-       if err := os.RemoveAll(tree.name); err != nil {
-               t.Errorf("removeTree: %v", err)
-       }
-}
-
 var basetests = []CleanTest{
        // Already clean
        {"", "."},
diff --git a/libgo/go/path/path_unix.go b/libgo/go/path/path_unix.go
deleted file mode 100644 (file)
index 7e8c5eb..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-// Copyright 2010 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package path
-
-const (
-       DirSeps    = `/`                  // directory separators
-       VolumeSeps = ``                   // volume separators
-       PathSeps   = DirSeps + VolumeSeps // all path separators
-)
diff --git a/libgo/go/path/path_windows.go b/libgo/go/path/path_windows.go
deleted file mode 100644 (file)
index 966eb49..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-// Copyright 2010 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package path
-
-const (
-       DirSeps    = `\/`                 // directory separators
-       VolumeSeps = `:`                  // volume separators
-       PathSeps   = DirSeps + VolumeSeps // all path separators
-)
index 42afadd85c08cbcdb04229ece2e8df5362307b8e..ac07ce5a36836fbbf8016832cb9323ffab7ac3c6 100644 (file)
@@ -968,28 +968,28 @@ func TestChan(t *testing.T) {
 
                // Recv
                c <- 3
-               if i := cv.Recv().(*IntValue).Get(); i != 3 {
-                       t.Errorf("native send 3, reflect Recv %d", i)
+               if i, ok := cv.Recv(); i.(*IntValue).Get() != 3 || !ok {
+                       t.Errorf("native send 3, reflect Recv %d, %t", i.(*IntValue).Get(), ok)
                }
 
                // TryRecv fail
-               val := cv.TryRecv()
-               if val != nil {
-                       t.Errorf("TryRecv on empty chan: %s", valueToString(val))
+               val, ok := cv.TryRecv()
+               if val != nil || ok {
+                       t.Errorf("TryRecv on empty chan: %s, %t", valueToString(val), ok)
                }
 
                // TryRecv success
                c <- 4
-               val = cv.TryRecv()
+               val, ok = cv.TryRecv()
                if val == nil {
                        t.Errorf("TryRecv on ready chan got nil")
-               } else if i := val.(*IntValue).Get(); i != 4 {
-                       t.Errorf("native send 4, TryRecv %d", i)
+               } else if i := val.(*IntValue).Get(); i != 4 || !ok {
+                       t.Errorf("native send 4, TryRecv %d, %t", i, ok)
                }
 
                // TrySend fail
                c <- 100
-               ok := cv.TrySend(NewValue(5))
+               ok = cv.TrySend(NewValue(5))
                i := <-c
                if ok {
                        t.Errorf("TrySend on full chan succeeded: value %d", i)
@@ -1008,20 +1008,11 @@ func TestChan(t *testing.T) {
                // Close
                c <- 123
                cv.Close()
-               if cv.Closed() {
-                       t.Errorf("closed too soon - 1")
+               if i, ok := cv.Recv(); i.(*IntValue).Get() != 123 || !ok {
+                       t.Errorf("send 123 then close; Recv %d, %t", i.(*IntValue).Get(), ok)
                }
-               if i := cv.Recv().(*IntValue).Get(); i != 123 {
-                       t.Errorf("send 123 then close; Recv %d", i)
-               }
-               if cv.Closed() {
-                       t.Errorf("closed too soon - 2")
-               }
-               if i := cv.Recv().(*IntValue).Get(); i != 0 {
-                       t.Errorf("after close Recv %d", i)
-               }
-               if !cv.Closed() {
-                       t.Errorf("not closed")
+               if i, ok := cv.Recv(); i.(*IntValue).Get() != 0 || ok {
+                       t.Errorf("after close Recv %d, %t", i.(*IntValue).Get(), ok)
                }
        }
 
@@ -1032,7 +1023,7 @@ func TestChan(t *testing.T) {
        if cv.TrySend(NewValue(7)) {
                t.Errorf("TrySend on sync chan succeeded")
        }
-       if cv.TryRecv() != nil {
+       if v, ok := cv.TryRecv(); v != nil || ok {
                t.Errorf("TryRecv on sync chan succeeded")
        }
 
index e7b68b3e720b2479117064736f05d04c0a7f6ca8..ebc87d45b9226703b5a2adc3f2783f7265f4137b 100644 (file)
@@ -671,19 +671,12 @@ func (v *ChanValue) Get() uintptr { return *(*uintptr)(v.addr) }
 
 // implemented in ../pkg/runtime/reflect.cgo
 func makechan(typ *runtime.ChanType, size uint32) (ch *byte)
-func chansend(ch, val *byte, pres *bool)
-func chanrecv(ch, val *byte, pres *bool)
-func chanclosed(ch *byte) bool
+func chansend(ch, val *byte, selected *bool)
+func chanrecv(ch, val *byte, selected *bool, ok *bool)
 func chanclose(ch *byte)
 func chanlen(ch *byte) int32
 func chancap(ch *byte) int32
 
-// Closed returns the result of closed(c) on the underlying channel.
-func (v *ChanValue) Closed() bool {
-       ch := *(**byte)(v.addr)
-       return chanclosed(ch)
-}
-
 // Close closes the channel.
 func (v *ChanValue) Close() {
        ch := *(**byte)(v.addr)
@@ -700,52 +693,61 @@ func (v *ChanValue) Cap() int {
        return int(chancap(ch))
 }
 
-// internal send; non-blocking if b != nil
-func (v *ChanValue) send(x Value, b *bool) {
+// internal send; non-blocking if selected != nil
+func (v *ChanValue) send(x Value, selected *bool) {
        t := v.Type().(*ChanType)
        if t.Dir()&SendDir == 0 {
                panic("send on recv-only channel")
        }
        typesMustMatch(t.Elem(), x.Type())
        ch := *(**byte)(v.addr)
-       chansend(ch, (*byte)(x.getAddr()), b)
+       chansend(ch, (*byte)(x.getAddr()), selected)
 }
 
-// internal recv; non-blocking if b != nil
-func (v *ChanValue) recv(b *bool) Value {
+// internal recv; non-blocking if selected != nil
+func (v *ChanValue) recv(selected *bool) (Value, bool) {
        t := v.Type().(*ChanType)
        if t.Dir()&RecvDir == 0 {
                panic("recv on send-only channel")
        }
        ch := *(**byte)(v.addr)
        x := MakeZero(t.Elem())
-       chanrecv(ch, (*byte)(x.getAddr()), b)
-       return x
+       var ok bool
+       chanrecv(ch, (*byte)(x.getAddr()), selected, &ok)
+       return x, ok
 }
 
 // Send sends x on the channel v.
 func (v *ChanValue) Send(x Value) { v.send(x, nil) }
 
 // Recv receives and returns a value from the channel v.
-func (v *ChanValue) Recv() Value { return v.recv(nil) }
+// The receive blocks until a value is ready.
+// The boolean value ok is true if the value x corresponds to a send
+// on the channel, false if it is a zero value received because the channel is closed.
+func (v *ChanValue) Recv() (x Value, ok bool) {
+       return v.recv(nil)
+}
 
 // TrySend attempts to sends x on the channel v but will not block.
 // It returns true if the value was sent, false otherwise.
 func (v *ChanValue) TrySend(x Value) bool {
-       var ok bool
-       v.send(x, &ok)
-       return ok
+       var selected bool
+       v.send(x, &selected)
+       return selected
 }
 
 // TryRecv attempts to receive a value from the channel v but will not block.
-// It returns the value if one is received, nil otherwise.
-func (v *ChanValue) TryRecv() Value {
-       var ok bool
-       x := v.recv(&ok)
-       if !ok {
-               return nil
-       }
-       return x
+// If the receive cannot finish without blocking, TryRecv instead returns x == nil.
+// If the receive can finish without blocking, TryRecv returns x != nil.
+// The boolean value ok is true if the value x corresponds to a send
+// on the channel, false if it is a zero value received because the channel is closed.
+func (v *ChanValue) TryRecv() (x Value, ok bool) {
+       var selected bool
+       x, ok = v.recv(&selected)
+       if !selected {
+               return nil, false
+       }
+       return x, ok
 }
 
 // MakeChan creates a new channel with the specified type and buffer size.
index 6de6d1325b6388891f65be0b7af24f4c78f5e967..92372521175f81c4c1e991a5b6421f88b5b2e004 100644 (file)
@@ -39,8 +39,9 @@ type Call struct {
 // There may be multiple outstanding Calls associated
 // with a single Client.
 type Client struct {
-       mutex    sync.Mutex // protects pending, seq
+       mutex    sync.Mutex // protects pending, seq, request
        sending  sync.Mutex
+       request  Request
        seq      uint64
        codec    ClientCodec
        pending  map[uint64]*Call
@@ -79,21 +80,21 @@ func (client *Client) send(c *Call) {
        client.mutex.Unlock()
 
        // Encode and send the request.
-       request := new(Request)
        client.sending.Lock()
        defer client.sending.Unlock()
-       request.Seq = c.seq
-       request.ServiceMethod = c.ServiceMethod
-       if err := client.codec.WriteRequest(request, c.Args); err != nil {
+       client.request.Seq = c.seq
+       client.request.ServiceMethod = c.ServiceMethod
+       if err := client.codec.WriteRequest(&client.request, c.Args); err != nil {
                panic("rpc: client encode error: " + err.String())
        }
 }
 
 func (client *Client) input() {
        var err os.Error
+       var response Response
        for err == nil {
-               response := new(Response)
-               err = client.codec.ReadResponseHeader(response)
+               response = Response{}
+               err = client.codec.ReadResponseHeader(&response)
                if err != nil {
                        if err == os.EOF && !client.closing {
                                err = io.ErrUnexpectedEOF
@@ -148,8 +149,12 @@ func (call *Call) done() {
 
 // NewClient returns a new Client to handle requests to the
 // set of services at the other end of the connection.
+// It adds a buffer to the write side of the connection so
+// the header and payload are sent as a unit.
 func NewClient(conn io.ReadWriteCloser) *Client {
-       return NewClientWithCodec(&gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)})
+       encBuf := bufio.NewWriter(conn)
+       client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
+       return NewClientWithCodec(client)
 }
 
 // NewClientWithCodec is like NewClient but uses the specified
@@ -164,16 +169,20 @@ func NewClientWithCodec(codec ClientCodec) *Client {
 }
 
 type gobClientCodec struct {
-       rwc io.ReadWriteCloser
-       dec *gob.Decoder
-       enc *gob.Encoder
+       rwc    io.ReadWriteCloser
+       dec    *gob.Decoder
+       enc    *gob.Encoder
+       encBuf *bufio.Writer
 }
 
-func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) os.Error {
-       if err := c.enc.Encode(r); err != nil {
-               return err
+func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) (err os.Error) {
+       if err = c.enc.Encode(r); err != nil {
+               return
+       }
+       if err = c.enc.Encode(body); err != nil {
+               return
        }
-       return c.enc.Encode(body)
+       return c.encBuf.Flush()
 }
 
 func (c *gobClientCodec) ReadResponseHeader(r *Response) os.Error {
@@ -273,6 +282,6 @@ func (client *Client) Call(serviceMethod string, args interface{}, reply interfa
        if client.shutdown {
                return ErrShutdown
        }
-       call := <-client.Go(serviceMethod, args, reply, nil).Done
+       call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
        return call.Error
 }
index 9dcda41480785faec736fdf071bdde64847e4530..1cc8c3173a82ab161e0e158a69a8f7824b96e998 100644 (file)
 package rpc
 
 import (
+       "bufio"
        "gob"
        "http"
        "log"
@@ -153,29 +154,29 @@ type service struct {
 // but documented here as an aid to debugging, such as when analyzing
 // network traffic.
 type Request struct {
-       ServiceMethod string // format: "Service.Method"
-       Seq           uint64 // sequence number chosen by client
+       ServiceMethod string   // format: "Service.Method"
+       Seq           uint64   // sequence number chosen by client
+       next          *Request // for free list in Server
 }
 
 // Response is a header written before every RPC return.  It is used internally
 // but documented here as an aid to debugging, such as when analyzing
 // network traffic.
 type Response struct {
-       ServiceMethod string // echoes that of the Request
-       Seq           uint64 // echoes that of the request
-       Error         string // error, if any.
-}
-
-// ClientInfo records information about an RPC client connection.
-type ClientInfo struct {
-       LocalAddr  string
-       RemoteAddr string
+       ServiceMethod string    // echoes that of the Request
+       Seq           uint64    // echoes that of the request
+       Error         string    // error, if any.
+       next          *Response // for free list in Server
 }
 
 // Server represents an RPC Server.
 type Server struct {
        sync.Mutex // protects the serviceMap
        serviceMap map[string]*service
+       reqLock    sync.Mutex // protects freeReq
+       freeReq    *Request
+       respLock   sync.Mutex // protects freeResp
+       freeResp   *Response
 }
 
 // NewServer returns a new Server.
@@ -269,13 +270,6 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) os.E
                        log.Println(mname, "reply type not exported:", replyType)
                        continue
                }
-               if mtype.NumIn() == 4 {
-                       t := mtype.In(3)
-                       if t != reflect.Typeof((*ClientInfo)(nil)) {
-                               log.Println(mname, "last argument not *ClientInfo")
-                               continue
-                       }
-               }
                // Method needs one out: os.Error.
                if mtype.NumOut() != 1 {
                        log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
@@ -298,9 +292,7 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) os.E
 }
 
 // A value sent as a placeholder for the response when the server receives an invalid request.
-type InvalidRequest struct {
-       Marker int
-}
+type InvalidRequest struct{}
 
 var invalidRequest = InvalidRequest{}
 
@@ -310,8 +302,8 @@ func _new(t *reflect.PtrType) *reflect.PtrValue {
        return v
 }
 
-func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
-       resp := new(Response)
+func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
+       resp := server.getResponse()
        // Encode the response header
        resp.ServiceMethod = req.ServiceMethod
        if errmsg != "" {
@@ -325,6 +317,7 @@ func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec Se
                log.Println("rpc: writing response:", err)
        }
        sending.Unlock()
+       server.freeResponse(resp)
 }
 
 func (m *methodType) NumCalls() (n uint) {
@@ -334,7 +327,7 @@ func (m *methodType) NumCalls() (n uint) {
        return n
 }
 
-func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
+func (s *service) call(server *Server, sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
        mtype.Lock()
        mtype.numCalls++
        mtype.Unlock()
@@ -347,13 +340,15 @@ func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, arg
        if errInter != nil {
                errmsg = errInter.(os.Error).String()
        }
-       sendResponse(sending, req, replyv.Interface(), codec, errmsg)
+       server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
+       server.freeRequest(req)
 }
 
 type gobServerCodec struct {
-       rwc io.ReadWriteCloser
-       dec *gob.Decoder
-       enc *gob.Encoder
+       rwc    io.ReadWriteCloser
+       dec    *gob.Decoder
+       enc    *gob.Encoder
+       encBuf *bufio.Writer
 }
 
 func (c *gobServerCodec) ReadRequestHeader(r *Request) os.Error {
@@ -364,11 +359,14 @@ func (c *gobServerCodec) ReadRequestBody(body interface{}) os.Error {
        return c.dec.Decode(body)
 }
 
-func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) os.Error {
-       if err := c.enc.Encode(r); err != nil {
-               return err
+func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) (err os.Error) {
+       if err = c.enc.Encode(r); err != nil {
+               return
+       }
+       if err = c.enc.Encode(body); err != nil {
+               return
        }
-       return c.enc.Encode(body)
+       return c.encBuf.Flush()
 }
 
 func (c *gobServerCodec) Close() os.Error {
@@ -382,7 +380,9 @@ func (c *gobServerCodec) Close() os.Error {
 // ServeConn uses the gob wire format (see package gob) on the
 // connection.  To use an alternate codec, use ServeCodec.
 func (server *Server) ServeConn(conn io.ReadWriteCloser) {
-       server.ServeCodec(&gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)})
+       buf := bufio.NewWriter(conn)
+       srv := &gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(buf), buf}
+       server.ServeCodec(srv)
 }
 
 // ServeCodec is like ServeConn but uses the specified codec to
@@ -403,7 +403,8 @@ func (server *Server) ServeCodec(codec ServerCodec) {
 
                        // send a response if we actually managed to read a header.
                        if req != nil {
-                               sendResponse(sending, req, invalidRequest, codec, err.String())
+                               server.sendResponse(sending, req, invalidRequest, codec, err.String())
+                               server.freeRequest(req)
                        }
                        continue
                }
@@ -419,16 +420,57 @@ func (server *Server) ServeCodec(codec ServerCodec) {
                                }
                                break
                        }
-                       sendResponse(sending, req, replyv.Interface(), codec, err.String())
+                       server.sendResponse(sending, req, replyv.Interface(), codec, err.String())
                        continue
                }
-               go service.call(sending, mtype, req, argv, replyv, codec)
+               go service.call(server, sending, mtype, req, argv, replyv, codec)
        }
        codec.Close()
 }
+
+func (server *Server) getRequest() *Request {
+       server.reqLock.Lock()
+       req := server.freeReq
+       if req == nil {
+               req = new(Request)
+       } else {
+               server.freeReq = req.next
+               *req = Request{}
+       }
+       server.reqLock.Unlock()
+       return req
+}
+
+func (server *Server) freeRequest(req *Request) {
+       server.reqLock.Lock()
+       req.next = server.freeReq
+       server.freeReq = req
+       server.reqLock.Unlock()
+}
+
+func (server *Server) getResponse() *Response {
+       server.respLock.Lock()
+       resp := server.freeResp
+       if resp == nil {
+               resp = new(Response)
+       } else {
+               server.freeResp = resp.next
+               *resp = Response{}
+       }
+       server.respLock.Unlock()
+       return resp
+}
+
+func (server *Server) freeResponse(resp *Response) {
+       server.respLock.Lock()
+       resp.next = server.freeResp
+       server.freeResp = resp
+       server.respLock.Unlock()
+}
+
 func (server *Server) readRequest(codec ServerCodec) (req *Request, service *service, mtype *methodType, err os.Error) {
        // Grab the request header.
-       req = new(Request)
+       req = server.getRequest()
        err = codec.ReadRequestHeader(req)
        if err != nil {
                req = nil
@@ -522,14 +564,14 @@ var connected = "200 Connected to Go RPC"
 // ServeHTTP implements an http.Handler that answers RPC requests.
 func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
        if req.Method != "CONNECT" {
-               w.SetHeader("Content-Type", "text/plain; charset=utf-8")
+               w.Header().Set("Content-Type", "text/plain; charset=utf-8")
                w.WriteHeader(http.StatusMethodNotAllowed)
                io.WriteString(w, "405 must CONNECT\n")
                return
        }
-       conn, _, err := w.Hijack()
+       conn, _, err := w.(http.Hijacker).Hijack()
        if err != nil {
-               log.Print("rpc hijacking ", w.RemoteAddr(), ": ", err.String())
+               log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.String())
                return
        }
        io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
index 05aaebceb487bf132a17deeb56efe01afe8f572b..d4041ae70ce8ac406c6fff5eb5c0e04c3aa039cf 100644 (file)
@@ -6,10 +6,11 @@ package rpc
 
 import (
        "fmt"
-       "http"
+       "http/httptest"
        "log"
        "net"
        "os"
+       "runtime"
        "strings"
        "sync"
        "testing"
@@ -103,11 +104,9 @@ func startNewServer() {
 }
 
 func startHttpServer() {
-       var l net.Listener
-       l, httpServerAddr = listenTCP()
-       httpServerAddr = l.Addr().String()
+       server := httptest.NewServer(nil)
+       httpServerAddr = server.Listener.Addr().String()
        log.Println("Test HTTP RPC server listening on", httpServerAddr)
-       go http.Serve(l, nil)
 }
 
 func TestRPC(t *testing.T) {
@@ -313,12 +312,12 @@ func (WriteFailCodec) WriteRequest(*Request, interface{}) os.Error {
 }
 
 func (WriteFailCodec) ReadResponseHeader(*Response) os.Error {
-       time.Sleep(60e9)
+       time.Sleep(120e9)
        panic("unreachable")
 }
 
 func (WriteFailCodec) ReadResponseBody(interface{}) os.Error {
-       time.Sleep(60e9)
+       time.Sleep(120e9)
        panic("unreachable")
 }
 
@@ -351,3 +350,52 @@ func testSendDeadlock(client *Client) {
        reply := new(Reply)
        client.Call("Arith.Add", args, reply)
 }
+
+func TestCountMallocs(t *testing.T) {
+       once.Do(startServer)
+       client, err := Dial("tcp", serverAddr)
+       if err != nil {
+               t.Error("error dialing", err)
+       }
+       args := &Args{7, 8}
+       reply := new(Reply)
+       mallocs := 0 - runtime.MemStats.Mallocs
+       const count = 100
+       for i := 0; i < count; i++ {
+               err = client.Call("Arith.Add", args, reply)
+               if err != nil {
+                       t.Errorf("Add: expected no error but got string %q", err.String())
+               }
+               if reply.C != args.A+args.B {
+                       t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+               }
+       }
+       mallocs += runtime.MemStats.Mallocs
+       fmt.Printf("mallocs per rpc round trip: %d\n", mallocs/count)
+}
+
+func BenchmarkEndToEnd(b *testing.B) {
+       b.StopTimer()
+       once.Do(startServer)
+       client, err := Dial("tcp", serverAddr)
+       if err != nil {
+               fmt.Println("error dialing", err)
+               return
+       }
+
+       // Synchronous calls
+       args := &Args{7, 8}
+       reply := new(Reply)
+       b.StartTimer()
+       for i := 0; i < b.N; i++ {
+               err = client.Call("Arith.Add", args, reply)
+               if err != nil {
+                       fmt.Printf("Add: expected no error but got string %q", err.String())
+                       break
+               }
+               if reply.C != args.A+args.B {
+                       fmt.Printf("Add: expected %d got %d", reply.C, args.A+args.B)
+                       break
+               }
+       }
+}
index 74010b394b069a2b64196d07192de5947052782e..6370a57d8029a31de8b06b83e823d13e0a6d6ca0 100644 (file)
@@ -4,8 +4,6 @@
 
 package runtime
 
-import "unsafe"
-
 // Breakpoint() executes a breakpoint trap.
 func Breakpoint()
 
@@ -31,65 +29,6 @@ func Cgocalls() int64
 // Goroutines returns the number of goroutines that currently exist.
 func Goroutines() int32
 
-type MemStatsType struct {
-       // General statistics.
-       // Not locked during update; approximate.
-       Alloc      uint64 // bytes allocated and still in use
-       TotalAlloc uint64 // bytes allocated (even if freed)
-       Sys        uint64 // bytes obtained from system (should be sum of XxxSys below)
-       Lookups    uint64 // number of pointer lookups
-       Mallocs    uint64 // number of mallocs
-       Frees      uint64 // number of frees
-
-       // Main allocation heap statistics.
-       HeapAlloc   uint64 // bytes allocated and still in use
-       HeapSys     uint64 // bytes obtained from system
-       HeapIdle    uint64 // bytes in idle spans
-       HeapInuse   uint64 // bytes in non-idle span
-       HeapObjects uint64 // total number of allocated objects
-
-       // Low-level fixed-size structure allocator statistics.
-       //      Inuse is bytes used now.
-       //      Sys is bytes obtained from system.
-       StackInuse  uint64 // bootstrap stacks
-       StackSys    uint64
-       MSpanInuse  uint64 // mspan structures
-       MSpanSys    uint64
-       MCacheInuse uint64 // mcache structures
-       MCacheSys   uint64
-       BuckHashSys uint64 // profiling bucket hash table
-
-       // Garbage collector statistics.
-       NextGC       uint64
-       PauseTotalNs uint64
-       PauseNs      [256]uint64 // most recent GC pause times
-       NumGC        uint32
-       EnableGC     bool
-       DebugGC      bool
-
-       // Per-size allocation statistics.
-       // Not locked during update; approximate.
-       // 61 is NumSizeClasses in the C code.
-       BySize [61]struct {
-               Size    uint32
-               Mallocs uint64
-               Frees   uint64
-       }
-}
-
-var Sizeof_C_MStats int // filled in by malloc.goc
-
-func init() {
-       if Sizeof_C_MStats != unsafe.Sizeof(MemStats) {
-               println(Sizeof_C_MStats, unsafe.Sizeof(MemStats))
-               panic("MStats vs MemStatsType size mismatch")
-       }
-}
-
-// MemStats holds statistics about the memory system.
-// The statistics are only approximate, as they are not interlocked on update.
-var MemStats MemStatsType
-
 // Alloc allocates a block of the given size.
 // FOR TESTING AND DEBUGGING ONLY.
 func Alloc(uintptr) *byte
@@ -102,9 +41,6 @@ func Free(*byte)
 // FOR TESTING AND DEBUGGING ONLY.
 func Lookup(*byte) (*byte, uintptr)
 
-// GC runs a garbage collection.
-func GC()
-
 // MemProfileRate controls the fraction of memory allocations
 // that are recorded and reported in the memory profile.
 // The profiler aims to sample an average of
@@ -156,4 +92,24 @@ func (r *MemProfileRecord) Stack() []uintptr {
 // where r.AllocBytes > 0 but r.AllocBytes == r.FreeBytes.
 // These are sites where memory was allocated, but it has all
 // been released back to the runtime.
+// Most clients should use the runtime/pprof package or
+// the testing package's -test.memprofile flag instead
+// of calling MemProfile directly.
 func MemProfile(p []MemProfileRecord, inuseZero bool) (n int, ok bool)
+
+// CPUProfile returns the next chunk of binary CPU profiling stack trace data,
+// blocking until data is available.  If profiling is turned off and all the profile
+// data accumulated while it was on has been returned, CPUProfile returns nil.
+// The caller must save the returned data before calling CPUProfile again.
+// Most clients should use the runtime/pprof package or
+// the testing package's -test.cpuprofile flag instead of calling
+// CPUProfile directly.
+func CPUProfile() []byte
+
+// SetCPUProfileRate sets the CPU profiling rate to hz samples per second.
+// If hz <= 0, SetCPUProfileRate turns off profiling.
+// If the profiler is on, the rate cannot be changed without first turning it off.
+// Most clients should use the runtime/pprof package or
+// the testing package's -test.cpuprofile flag instead of calling
+// SetCPUProfileRate directly.
+func SetCPUProfileRate(hz int)
diff --git a/libgo/go/runtime/mem.go b/libgo/go/runtime/mem.go
new file mode 100644 (file)
index 0000000..2fc1892
--- /dev/null
@@ -0,0 +1,69 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package runtime
+
+import "unsafe"
+
+type MemStatsType struct {
+       // General statistics.
+       // Not locked during update; approximate.
+       Alloc      uint64 // bytes allocated and still in use
+       TotalAlloc uint64 // bytes allocated (even if freed)
+       Sys        uint64 // bytes obtained from system (should be sum of XxxSys below)
+       Lookups    uint64 // number of pointer lookups
+       Mallocs    uint64 // number of mallocs
+       Frees      uint64 // number of frees
+
+       // Main allocation heap statistics.
+       HeapAlloc   uint64 // bytes allocated and still in use
+       HeapSys     uint64 // bytes obtained from system
+       HeapIdle    uint64 // bytes in idle spans
+       HeapInuse   uint64 // bytes in non-idle span
+       HeapObjects uint64 // total number of allocated objects
+
+       // Low-level fixed-size structure allocator statistics.
+       //      Inuse is bytes used now.
+       //      Sys is bytes obtained from system.
+       StackInuse  uint64 // bootstrap stacks
+       StackSys    uint64
+       MSpanInuse  uint64 // mspan structures
+       MSpanSys    uint64
+       MCacheInuse uint64 // mcache structures
+       MCacheSys   uint64
+       BuckHashSys uint64 // profiling bucket hash table
+
+       // Garbage collector statistics.
+       NextGC       uint64
+       PauseTotalNs uint64
+       PauseNs      [256]uint64 // most recent GC pause times
+       NumGC        uint32
+       EnableGC     bool
+       DebugGC      bool
+
+       // Per-size allocation statistics.
+       // Not locked during update; approximate.
+       // 61 is NumSizeClasses in the C code.
+       BySize [61]struct {
+               Size    uint32
+               Mallocs uint64
+               Frees   uint64
+       }
+}
+
+var Sizeof_C_MStats int // filled in by malloc.goc
+
+func init() {
+       if Sizeof_C_MStats != unsafe.Sizeof(MemStats) {
+               println(Sizeof_C_MStats, unsafe.Sizeof(MemStats))
+               panic("MStats vs MemStatsType size mismatch")
+       }
+}
+
+// MemStats holds statistics about the memory system.
+// The statistics are only approximate, as they are not interlocked on update.
+var MemStats MemStatsType
+
+// GC runs a garbage collection.
+func GC()
index 9bee5112819dff25ded674bbe3023ec77973f6a6..fdeceb4e8dc6747e612615f92c38cceccc188b8b 100644 (file)
@@ -14,6 +14,7 @@ import (
        "io"
        "os"
        "runtime"
+       "sync"
 )
 
 // WriteHeapProfile writes a pprof-formatted heap profile to w.
@@ -105,3 +106,71 @@ func WriteHeapProfile(w io.Writer) os.Error {
        }
        return b.Flush()
 }
+
+var cpu struct {
+       sync.Mutex
+       profiling bool
+       done      chan bool
+}
+
+// StartCPUProfile enables CPU profiling for the current process.
+// While profiling, the profile will be buffered and written to w.
+// StartCPUProfile returns an error if profiling is already enabled.
+func StartCPUProfile(w io.Writer) os.Error {
+       // The runtime routines allow a variable profiling rate,
+       // but in practice operating systems cannot trigger signals
+       // at more than about 500 Hz, and our processing of the
+       // signal is not cheap (mostly getting the stack trace).
+       // 100 Hz is a reasonable choice: it is frequent enough to
+       // produce useful data, rare enough not to bog down the
+       // system, and a nice round number to make it easy to
+       // convert sample counts to seconds.  Instead of requiring
+       // each client to specify the frequency, we hard code it.
+       const hz = 100
+
+       // Avoid queueing behind StopCPUProfile.
+       // Could use TryLock instead if we had it.
+       if cpu.profiling {
+               return fmt.Errorf("cpu profiling already in use")
+       }
+
+       cpu.Lock()
+       defer cpu.Unlock()
+       if cpu.done == nil {
+               cpu.done = make(chan bool)
+       }
+       // Double-check.
+       if cpu.profiling {
+               return fmt.Errorf("cpu profiling already in use")
+       }
+       cpu.profiling = true
+       runtime.SetCPUProfileRate(hz)
+       go profileWriter(w)
+       return nil
+}
+
+func profileWriter(w io.Writer) {
+       for {
+               data := runtime.CPUProfile()
+               if data == nil {
+                       break
+               }
+               w.Write(data)
+       }
+       cpu.done <- true
+}
+
+// StopCPUProfile stops the current CPU profile, if any.
+// StopCPUProfile only returns after all the writes for the
+// profile have completed.
+func StopCPUProfile() {
+       cpu.Lock()
+       defer cpu.Unlock()
+
+       if !cpu.profiling {
+               return
+       }
+       cpu.profiling = false
+       runtime.SetCPUProfileRate(0)
+       <-cpu.done
+}
diff --git a/libgo/go/runtime/pprof/pprof_test.go b/libgo/go/runtime/pprof/pprof_test.go
new file mode 100644 (file)
index 0000000..603465e
--- /dev/null
@@ -0,0 +1,69 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package pprof_test
+
+import (
+       "bytes"
+       "hash/crc32"
+       "runtime"
+       . "runtime/pprof"
+       "strings"
+       "testing"
+       "unsafe"
+)
+
+func TestCPUProfile(t *testing.T) {
+       if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
+               return
+       }
+
+       buf := make([]byte, 100000)
+       var prof bytes.Buffer
+       if err := StartCPUProfile(&prof); err != nil {
+               t.Fatal(err)
+       }
+       // This loop takes about a quarter second on a 2 GHz laptop.
+       // We only need to get one 100 Hz clock tick, so we've got
+       // a 25x safety buffer.
+       for i := 0; i < 1000; i++ {
+               crc32.ChecksumIEEE(buf)
+       }
+       StopCPUProfile()
+
+       // Convert []byte to []uintptr.
+       bytes := prof.Bytes()
+       val := *(*[]uintptr)(unsafe.Pointer(&bytes))
+       val = val[:len(bytes)/unsafe.Sizeof(uintptr(0))]
+
+       if len(val) < 10 {
+               t.Fatalf("profile too short: %#x", val)
+       }
+       if val[0] != 0 || val[1] != 3 || val[2] != 0 || val[3] != 1e6/100 || val[4] != 0 {
+               t.Fatalf("unexpected header %#x", val[:5])
+       }
+
+       // Check that profile is well formed and contains ChecksumIEEE.
+       found := false
+       val = val[5:]
+       for len(val) > 0 {
+               if len(val) < 2 || val[0] < 1 || val[1] < 1 || uintptr(len(val)) < 2+val[1] {
+                       t.Fatalf("malformed profile.  leftover: %#x", val)
+               }
+               for _, pc := range val[2 : 2+val[1]] {
+                       f := runtime.FuncForPC(pc)
+                       if f == nil {
+                               continue
+                       }
+                       if strings.Contains(f.Name(), "ChecksumIEEE") {
+                               found = true
+                       }
+               }
+               val = val[2+val[1]:]
+       }
+
+       if !found {
+               t.Fatal("did not find ChecksumIEEE in the profile")
+       }
+}
index 98a0d5731ec3d9f547e1f8e7e6206baa990fa718..5f009e54859146fb2d17345686758006d496277b 100644 (file)
@@ -119,9 +119,19 @@ func LastIndex(s, sep string) int {
 // IndexRune returns the index of the first instance of the Unicode code point
 // rune, or -1 if rune is not present in s.
 func IndexRune(s string, rune int) int {
-       for i, c := range s {
-               if c == rune {
-                       return i
+       switch {
+       case rune < 0x80:
+               b := byte(rune)
+               for i := 0; i < len(s); i++ {
+                       if s[i] == b {
+                               return i
+                       }
+               }
+       default:
+               for i, c := range s {
+                       if c == rune {
+                               return i
+                       }
                }
        }
        return -1
index 734fdd33daa71f1ca90e5b6527c3a03ebb6120cf..41e398782e686d6c60ef785de276106973a6dd6e 100644 (file)
@@ -6,6 +6,7 @@ package strings_test
 
 import (
        "os"
+       "strconv"
        . "strings"
        "testing"
        "unicode"
@@ -116,6 +117,57 @@ func TestLastIndex(t *testing.T)    { runIndexTests(t, LastIndex, "LastIndex", l
 func TestIndexAny(t *testing.T)     { runIndexTests(t, IndexAny, "IndexAny", indexAnyTests) }
 func TestLastIndexAny(t *testing.T) { runIndexTests(t, LastIndexAny, "LastIndexAny", lastIndexAnyTests) }
 
+type IndexRuneTest struct {
+       s    string
+       rune int
+       out  int
+}
+
+var indexRuneTests = []IndexRuneTest{
+       {"a A x", 'A', 2},
+       {"some_text=some_value", '=', 9},
+       {"☺a", 'a', 3},
+       {"a☻☺b", '☺', 4},
+}
+
+func TestIndexRune(t *testing.T) {
+       for _, test := range indexRuneTests {
+               if actual := IndexRune(test.s, test.rune); actual != test.out {
+                       t.Errorf("IndexRune(%q,%d)= %v; want %v", test.s, test.rune, actual, test.out)
+               }
+       }
+}
+
+const benchmarkString = "some_text=some☺value"
+
+func BenchmarkIndexRune(b *testing.B) {
+       if got := IndexRune(benchmarkString, '☺'); got != 14 {
+               panic("wrong index: got=" + strconv.Itoa(got))
+       }
+       for i := 0; i < b.N; i++ {
+               IndexRune(benchmarkString, '☺')
+       }
+}
+
+func BenchmarkIndexRuneFastPath(b *testing.B) {
+       if got := IndexRune(benchmarkString, 'v'); got != 17 {
+               panic("wrong index: got=" + strconv.Itoa(got))
+       }
+       for i := 0; i < b.N; i++ {
+               IndexRune(benchmarkString, 'v')
+       }
+}
+
+func BenchmarkIndex(b *testing.B) {
+       if got := Index(benchmarkString, "v"); got != 17 {
+               panic("wrong index: got=" + strconv.Itoa(got))
+       }
+       for i := 0; i < b.N; i++ {
+               Index(benchmarkString, "v")
+       }
+}
+
+
 type ExplodeTest struct {
        s string
        n int
index 68e1d509f4811c8161901396fd880f2d098289e7..05478c630667eb705515e91305b331a850d49c45 100644 (file)
@@ -22,7 +22,7 @@ import "runtime"
 //       go func() {
 //           // Do something.
 //           wg.Done()
-//       }
+//       }()
 //   }
 //   wg.Wait()
 // 
index c3cb8901a09f473290ee60d3fec1e6ef60d7a50b..ba06de4e3ab024cbff8ddb72d5c8a3230f584426 100644 (file)
@@ -267,7 +267,6 @@ func (t *Template) nextItem() []byte {
        }
        leadingSpace := i > start
        // What's left is nothing, newline, delimited string, or plain text
-Switch:
        switch {
        case i == len(t.buf):
                // EOF; nothing to do
@@ -896,8 +895,8 @@ func (t *Template) executeRepeated(r *repeatedElement, st *state) {
                }
        } else if ch := iter(field); ch != nil {
                for {
-                       e := ch.Recv()
-                       if ch.Closed() {
+                       e, ok := ch.Recv()
+                       if !ok {
                                break
                        }
                        loopBody(st.clone(e))
index 11f5a74251ade832819155b2a0cc85bdc36af96e..b341b1f896b98e1e29daa319ade2289a73aa06bb 100644 (file)
@@ -306,8 +306,8 @@ func recvValues(multiplex chan<- interface{}, channel interface{}) {
        c := reflect.NewValue(channel).(*reflect.ChanValue)
 
        for {
-               v := c.Recv()
-               if c.Closed() {
+               v, ok := c.Recv()
+               if !ok {
                        multiplex <- channelClosed{channel}
                        return
                }
index 324b5a70e1fa3c64899c69c91f462c9de1139a93..ab8cf999a2564dd6c3d7c0b4f74f79bc60218e5c 100644 (file)
@@ -43,12 +43,18 @@ import (
        "fmt"
        "os"
        "runtime"
+       "runtime/pprof"
        "time"
 )
 
-// Report as tests are run; default is silent for success.
-var chatty = flag.Bool("test.v", false, "verbose: print additional output")
-var match = flag.String("test.run", "", "regular expression to select tests to run")
+var (
+       // Report as tests are run; default is silent for success.
+       chatty         = flag.Bool("test.v", false, "verbose: print additional output")
+       match          = flag.String("test.run", "", "regular expression to select tests to run")
+       memProfile     = flag.String("test.memprofile", "", "write a memory profile to the named file after execution")
+       memProfileRate = flag.Int("test.memprofilerate", 0, "if >=0, sets runtime.MemProfileRate")
+       cpuProfile     = flag.String("test.cpuprofile", "", "write a cpu profile to the named file during execution")
+)
 
 
 // Insert final newline if needed and tabs after internal newlines.
@@ -136,8 +142,16 @@ func tRunner(t *T, test *InternalTest) {
 
 // An internal function but exported because it is cross-package; part of the implementation
 // of gotest.
-func Main(matchString func(pat, str string) (bool, os.Error), tests []InternalTest) {
+func Main(matchString func(pat, str string) (bool, os.Error), tests []InternalTest, benchmarks []InternalBenchmark) {
        flag.Parse()
+
+       before()
+       RunTests(matchString, tests)
+       RunBenchmarks(matchString, benchmarks)
+       after()
+}
+
+func RunTests(matchString func(pat, str string) (bool, os.Error), tests []InternalTest) {
        ok := true
        if len(tests) == 0 {
                println("testing: warning: no tests to run")
@@ -176,3 +190,42 @@ func Main(matchString func(pat, str string) (bool, os.Error), tests []InternalTe
        }
        println("PASS")
 }
+
+// before runs before all testing.
+func before() {
+       if *memProfileRate > 0 {
+               runtime.MemProfileRate = *memProfileRate
+       }
+       if *cpuProfile != "" {
+               f, err := os.Open(*cpuProfile, os.O_WRONLY|os.O_CREAT|os.O_TRUNC, 0666)
+               if err != nil {
+                       fmt.Fprintf(os.Stderr, "testing: %s", err)
+                       return
+               }
+               if err := pprof.StartCPUProfile(f); err != nil {
+                       fmt.Fprintf(os.Stderr, "testing: can't start cpu profile: %s", err)
+                       f.Close()
+                       return
+               }
+               // Could save f so after can call f.Close; not worth the effort.
+       }
+
+}
+
+// after runs after all testing.
+func after() {
+       if *cpuProfile != "" {
+               pprof.StopCPUProfile() // flushes profile to disk
+       }
+       if *memProfile != "" {
+               f, err := os.Open(*memProfile, os.O_WRONLY|os.O_CREAT|os.O_TRUNC, 0666)
+               if err != nil {
+                       fmt.Fprintf(os.Stderr, "testing: %s", err)
+                       return
+               }
+               if err = pprof.WriteHeapProfile(f); err != nil {
+                       fmt.Fprintf(os.Stderr, "testing: can't write %s: %s", *memProfile, err)
+               }
+               f.Close()
+       }
+}
index 833552d68412ca22fe0495a66223cd02f9c3d776..3bc253c94a3908c891bdd6da4bb9093fa68f7d18 100644 (file)
@@ -5,10 +5,8 @@
 package time
 
 import (
-       "os"
-       "syscall"
-       "sync"
        "container/heap"
+       "sync"
 )
 
 // The Timer type represents a single event.
@@ -47,30 +45,6 @@ func init() {
        timers.Push(&Timer{t: forever}) // sentinel
 }
 
-// Sleep pauses the current goroutine for at least ns nanoseconds.
-// Higher resolution sleeping may be provided by syscall.Nanosleep 
-// on some operating systems.
-func Sleep(ns int64) os.Error {
-       _, err := sleep(Nanoseconds(), ns)
-       return err
-}
-
-// sleep takes the current time and a duration,
-// pauses for at least ns nanoseconds, and
-// returns the current time and an error.
-func sleep(t, ns int64) (int64, os.Error) {
-       // TODO(cw): use monotonic-time once it's available
-       end := t + ns
-       for t < end {
-               errno := syscall.Sleep(end - t)
-               if errno != 0 && errno != syscall.EINTR {
-                       return 0, os.NewSyscallError("sleep", errno)
-               }
-               t = Nanoseconds()
-       }
-       return t, nil
-}
-
 // NewTimer creates a new Timer that will send
 // the current time on its channel after at least ns nanoseconds.
 func NewTimer(ns int64) *Timer {
@@ -151,7 +125,7 @@ func sleeper(sleeperId int64) {
                                dt = maxSleepTime
                        }
                        timerMutex.Unlock()
-                       syscall.Sleep(dt)
+                       sysSleep(dt)
                        timerMutex.Lock()
                        if currentSleeper != sleeperId {
                                // Another sleeper has been started, making this one redundant.
index 8bf599c3e1a96d3dc5a893331692b10efa0a43a7..5fe4d7f15b5c3eff3d1b55c71048d04b5ce726ba 100644 (file)
@@ -132,7 +132,9 @@ func TestAfterStop(t *testing.T) {
        }
 }
 
-var slots = []int{5, 3, 6, 6, 6, 1, 1, 2, 7, 9, 4, 8, 0}
+// For gccgo omit 0 for now because it can take too long to start the
+// thread.
+var slots = []int{5, 3, 6, 6, 6, 1, 1, 2, 7, 9, 4, 8, /*0*/}
 
 type afterResult struct {
        slot int
diff --git a/libgo/go/time/sys.go b/libgo/go/time/sys.go
new file mode 100644 (file)
index 0000000..63f4cbf
--- /dev/null
@@ -0,0 +1,62 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package time
+
+import (
+       "os"
+       "syscall"
+)
+
+// Seconds reports the number of seconds since the Unix epoch,
+// January 1, 1970 00:00:00 UTC.
+func Seconds() int64 {
+       sec, _, err := os.Time()
+       if err != nil {
+               panic(err)
+       }
+       return sec
+}
+
+// Nanoseconds reports the number of nanoseconds since the Unix epoch,
+// January 1, 1970 00:00:00 UTC.
+func Nanoseconds() int64 {
+       sec, nsec, err := os.Time()
+       if err != nil {
+               panic(err)
+       }
+       return sec*1e9 + nsec
+}
+
+// Sleep pauses the current goroutine for at least ns nanoseconds.
+// Higher resolution sleeping may be provided by syscall.Nanosleep 
+// on some operating systems.
+func Sleep(ns int64) os.Error {
+       _, err := sleep(Nanoseconds(), ns)
+       return err
+}
+
+// sleep takes the current time and a duration,
+// pauses for at least ns nanoseconds, and
+// returns the current time and an error.
+func sleep(t, ns int64) (int64, os.Error) {
+       // TODO(cw): use monotonic-time once it's available
+       end := t + ns
+       for t < end {
+               err := sysSleep(end - t)
+               if err != nil {
+                       return 0, err
+               }
+               t = Nanoseconds()
+       }
+       return t, nil
+}
+
+func sysSleep(t int64) os.Error {
+       errno := syscall.Sleep(t)
+       if errno != 0 && errno != syscall.EINTR {
+               return os.NewSyscallError("sleep", errno)
+       }
+       return nil
+}
index 4abd112308bbfccd4d2de1d50afd8f21a63377d4..40338f7752a48c61c2db906be9018e660bb2ec4e 100644 (file)
@@ -6,30 +6,6 @@
 // displaying time.
 package time
 
-import (
-       "os"
-)
-
-// Seconds reports the number of seconds since the Unix epoch,
-// January 1, 1970 00:00:00 UTC.
-func Seconds() int64 {
-       sec, _, err := os.Time()
-       if err != nil {
-               panic(err)
-       }
-       return sec
-}
-
-// Nanoseconds reports the number of nanoseconds since the Unix epoch,
-// January 1, 1970 00:00:00 UTC.
-func Nanoseconds() int64 {
-       sec, nsec, err := os.Time()
-       if err != nil {
-               panic(err)
-       }
-       return sec*1e9 + nsec
-}
-
 // Days of the week.
 const (
        Sunday = iota
@@ -47,7 +23,7 @@ type Time struct {
        Month, Day           int    // Jan-2 is 1, 2
        Hour, Minute, Second int    // 15:04:05 is 15, 4, 5.
        Weekday              int    // Sunday, Monday, ...
-       ZoneOffset           int    // seconds east of UTC, e.g. -7*60 for -0700
+       ZoneOffset           int    // seconds east of UTC, e.g. -7*60*60 for -0700
        Zone                 string // e.g., "MST"
 }
 
index c86bca1b49875c823ed1a524e3e157256f77b64c..1d83291c097a8bd88bba1049218ed4db74c42fbe 100644 (file)
@@ -19,6 +19,18 @@ func init() {
        os.Setenv("TZ", "America/Los_Angeles")
 }
 
+// We should be in PST/PDT, but if the time zone files are missing we
+// won't be. The purpose of this test is to at least explain why some of
+// the subsequent tests fail.
+func TestZoneData(t *testing.T) {
+       lt := LocalTime()
+       // PST is 8 hours west, PDT is 7 hours west.  We could use the name but it's not unique.
+       if off := lt.ZoneOffset; off != -8*60*60 && off != -7*60*60 {
+               t.Errorf("Unable to find US Pacific time zone data for testing; time zone is %q offset %d", lt.Zone, off)
+               t.Error("Likely problem: the time zone files have not been installed.")
+       }
+}
+
 type TimeTest struct {
        seconds int64
        golden  Time
index 25f057ba5b0b4d7763aad84b97b9be9dc1262469..1119b2d34ebdf906e14f0a96e962a891352d09c5 100644 (file)
@@ -58,7 +58,7 @@ func getKeyNumber(s string) (r uint32) {
 
 // ServeHTTP implements the http.Handler interface for a Web Socket
 func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
-       rwc, buf, err := w.Hijack()
+       rwc, buf, err := w.(http.Hijacker).Hijack()
        if err != nil {
                panic("Hijack failed: " + err.String())
                return
@@ -98,7 +98,7 @@ func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
        }
 
        var location string
-       if w.UsingTLS() {
+       if req.TLS != nil {
                location = "wss://" + req.Host + req.URL.RawPath
        } else {
                location = "ws://" + req.Host + req.URL.RawPath
@@ -184,7 +184,7 @@ func (f Draft75Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
                return
        }
 
-       rwc, buf, err := w.Hijack()
+       rwc, buf, err := w.(http.Hijacker).Hijack()
        if err != nil {
                panic("Hijack failed: " + err.String())
                return
@@ -192,7 +192,7 @@ func (f Draft75Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
        defer rwc.Close()
 
        var location string
-       if w.UsingTLS() {
+       if req.TLS != nil {
                location = "wss://" + req.Host + req.URL.RawPath
        } else {
                location = "ws://" + req.Host + req.URL.RawPath
index 204a9de1e12957125464a6500677444bd438cc75..14d708a3babd1e77ce9a7f98965c4ee839f7230e 100644 (file)
@@ -9,6 +9,7 @@ import (
        "bytes"
        "fmt"
        "http"
+       "http/httptest"
        "io"
        "log"
        "net"
@@ -22,15 +23,11 @@ var once sync.Once
 func echoServer(ws *Conn) { io.Copy(ws, ws) }
 
 func startServer() {
-       l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
-       if e != nil {
-               log.Fatalf("net.Listen tcp :0 %v", e)
-       }
-       serverAddr = l.Addr().String()
-       log.Print("Test WebSocket server listening on ", serverAddr)
        http.Handle("/echo", Handler(echoServer))
        http.Handle("/echoDraft75", Draft75Handler(echoServer))
-       go http.Serve(l, nil)
+       server := httptest.NewServer(nil)
+       serverAddr = server.Listener.Addr().String()
+       log.Print("Test WebSocket server listening on ", serverAddr)
 }
 
 // Test the getChallengeResponse function with values from section
index 691c13a1188a4f8c11f620bf46f9a512e265037d..f92abe825603f79e4b44b37e831167b356abcc57 100644 (file)
@@ -815,7 +815,6 @@ Input:
                        // Parsers are required to recognize lt, gt, amp, apos, and quot
                        // even if they have not been declared.  That's all we allow.
                        var i int
-               CharLoop:
                        for i = 0; i < len(p.tmp); i++ {
                                var ok bool
                                p.tmp[i], ok = p.getc()
index 17dc57340240b0a6ed819faf1d1a31681e9fac38..3696783abbe3277c496a126ba142c55dc0a10bb1 100644 (file)
@@ -146,7 +146,7 @@ done
   done
 done
 
-runtime="goc2c.c mcache.c mcentral.c mfinal.c mfixalloc.c mgc0.c mheap.c mheapmap32.c mheapmap64.c msize.c malloc.h mheapmap32.h mheapmap64.h malloc.goc mprof.goc"
+runtime="goc2c.c mcache.c mcentral.c mfinal.c mfixalloc.c mgc0.c mheap.c msize.c malloc.h malloc.goc mprof.goc"
 for f in $runtime; do
   oldfile=${OLDDIR}/src/pkg/runtime/$f
   newfile=${NEWDIR}/src/pkg/runtime/$f
index ea108289bf3974df4d697396e7577a9c62f3863a..743af8bee9ace6886ef46ee82baf972b96da06c9 100644 (file)
@@ -119,8 +119,12 @@ extern void __go_receive_release (struct __go_channel *);
 
 struct __go_receive_nonblocking_small
 {
+  /* Value read from channel, or 0.  */
   uint64_t __val;
+  /* True if value was read from channel.  */
   _Bool __success;
+  /* True if channel is closed.  */
+  _Bool __closed;
 };
 
 extern struct __go_receive_nonblocking_small
@@ -128,7 +132,8 @@ __go_receive_nonblocking_small (struct __go_channel *);
 
 extern _Bool __go_receive_big (struct __go_channel *, void *, _Bool);
 
-extern _Bool __go_receive_nonblocking_big (struct __go_channel *, void *);
+extern _Bool __go_receive_nonblocking_big (struct __go_channel *, void *,
+                                          _Bool *);
 
 extern void __go_unlock_and_notify_selects (struct __go_channel *);
 
index a584fe7a2731614714094e53ce0981b2827a2f30..fd3923ce272e1137ac14ba8d1c66be46d7b84fc0 100644 (file)
@@ -9,6 +9,9 @@
 #include "go-panic.h"
 #include "channel.h"
 
+/* Returns true if a value was received, false if the channel is
+   closed.  */
+
 _Bool
 __go_receive_big (struct __go_channel *channel, void *val, _Bool for_select)
 {
index 53ffe48ab97d69d3023c4e7b1ef5962036ca75c9..78db587345f7280ae7d35d8b399b43725f7b31c9 100644 (file)
@@ -8,8 +8,11 @@
 
 #include "channel.h"
 
+/* Return true if a value was received, false if not.  */
+
 _Bool
-__go_receive_nonblocking_big (struct __go_channel* channel, void *val)
+__go_receive_nonblocking_big (struct __go_channel* channel, void *val,
+                             _Bool *closed)
 {
   size_t alloc_size;
   size_t offset;
@@ -21,13 +24,9 @@ __go_receive_nonblocking_big (struct __go_channel* channel, void *val)
   if (data != RECEIVE_NONBLOCKING_ACQUIRE_DATA)
     {
       __builtin_memset (val, 0, channel->element_size);
-      if (data == RECEIVE_NONBLOCKING_ACQUIRE_NODATA)
-       return 0;
-      else
-       {
-         /* Channel is closed.  */
-         return 1;
-       }
+      if (closed != NULL)
+       *closed = data == RECEIVE_NONBLOCKING_ACQUIRE_CLOSED;
+      return 0;
     }
 
   offset = channel->next_fetch * alloc_size;
index d77a2ace432676e8b95d9fb93d8aef74c7dc056f..d09901b0c88365348c1a5f6cee532b82b541a529 100644 (file)
@@ -103,7 +103,8 @@ __go_receive_nonblocking_small (struct __go_channel *channel)
   if (data != RECEIVE_NONBLOCKING_ACQUIRE_DATA)
     {
       ret.__val = 0;
-      ret.__success = data == RECEIVE_NONBLOCKING_ACQUIRE_CLOSED;
+      ret.__success = 0;
+      ret.__closed = data == RECEIVE_NONBLOCKING_ACQUIRE_CLOSED;
       return ret;
     }
 
@@ -112,6 +113,7 @@ __go_receive_nonblocking_small (struct __go_channel *channel)
   __go_receive_release (channel);
 
   ret.__success = 1;
+  ret.__closed = 0;
 
   return ret;
 }
index 87aed3cd552fd3fef6b95d27e671c3802c6456d7..ee85cde566be4e5457146c9bab981a8a42f3146b 100644 (file)
@@ -96,7 +96,8 @@ __go_broadcast_to_select (struct __go_channel *channel)
 }
 
 /* Prepare to receive something on a channel.  Return true if the
-   channel is acquired, false if it is closed.  */
+   channel is acquired (which implies that there is data available),
+   false if it is closed.  */
 
 _Bool
 __go_receive_acquire (struct __go_channel *channel, _Bool for_select)
index 412cfeedfe37de22169c27dfe555a58084b803ed..6ec1b9a2dc7c0342577b242325ad0efacbfaafc2 100644 (file)
@@ -27,7 +27,7 @@ extern void chansend (unsigned char *, unsigned char *, _Bool *)
   asm ("libgo_reflect.reflect.chansend");
 
 void
-chansend (unsigned char *ch, unsigned char *val, _Bool *pres)
+chansend (unsigned char *ch, unsigned char *val, _Bool *selected)
 {
   struct __go_channel *channel = (struct __go_channel *) ch;
 
@@ -46,25 +46,26 @@ chansend (unsigned char *ch, unsigned char *val, _Bool *pres)
       __builtin_memcpy (u.b + sizeof (uint64_t) - channel->element_size, val,
                        channel->element_size);
 #endif
-      if (pres == NULL)
+      if (selected == NULL)
        __go_send_small (channel, u.v, 0);
       else
-       *pres = __go_send_nonblocking_small (channel, u.v);
+       *selected = __go_send_nonblocking_small (channel, u.v);
     }
   else
     {
-      if (pres == NULL)
+      if (selected == NULL)
        __go_send_big (channel, val, 0);
       else
-       *pres = __go_send_nonblocking_big (channel, val);
+       *selected = __go_send_nonblocking_big (channel, val);
     }
 }
 
-extern void chanrecv (unsigned char *, unsigned char *, _Bool *)
+extern void chanrecv (unsigned char *, unsigned char *, _Bool *, _Bool *)
   asm ("libgo_reflect.reflect.chanrecv");
 
 void
-chanrecv (unsigned char *ch, unsigned char *val, _Bool *pres)
+chanrecv (unsigned char *ch, unsigned char *val, _Bool *selected,
+         _Bool *received)
 {
   struct __go_channel *channel = (struct __go_channel *) ch;
 
@@ -76,16 +77,16 @@ chanrecv (unsigned char *ch, unsigned char *val, _Bool *pres)
        uint64_t v;
       } u;
 
-      if (pres == NULL)
-       u.v = __go_receive_small (channel, 0);
+      if (selected == NULL)
+       u.v = __go_receive_small_closed (channel, 0, received);
       else
        {
          struct __go_receive_nonblocking_small s;
 
          s = __go_receive_nonblocking_small (channel);
-         *pres = s.__success;
-         if (!s.__success)
-           return;
+         *selected = s.__success || s.__closed;
+         if (received != NULL)
+           *received = s.__success;
          u.v = s.__val;
        }
 
@@ -98,10 +99,24 @@ chanrecv (unsigned char *ch, unsigned char *val, _Bool *pres)
     }
   else
     {
-      if (pres == NULL)
-       __go_receive_big (channel, val, 0);
+      if (selected == NULL)
+       {
+         _Bool success;
+
+         success = __go_receive_big (channel, val, 0);
+         if (received != NULL)
+           *received = success;
+       }
       else
-       *pres = __go_receive_nonblocking_big (channel, val);
+       {
+         _Bool got;
+         _Bool closed;
+
+         got = __go_receive_nonblocking_big (channel, val, &closed);
+         *selected = got || closed;
+         if (received != NULL)
+           *received = got;
+       }
     }
 }
 
index 7a9ee2825dccd0f7732e560a1fccdb1fedd18cdf..64b40bdc955ce8c1e2ad7700b694526dee5c3513 100644 (file)
@@ -12,6 +12,7 @@ import "unsafe"
 
 func libc_fcntl(fd int, cmd int, arg int) int __asm__ ("fcntl")
 func libc_fork() Pid_t __asm__ ("fork")
+func libc_setsid() Pid_t __asm__ ("setsid")
 func libc_chdir(name *byte) int __asm__ ("chdir")
 func libc_dup2(int, int) int __asm__ ("dup2")
 func libc_execve(*byte, **byte, **byte) int __asm__ ("execve")
@@ -24,13 +25,16 @@ func libc_wait4(Pid_t, *int, int, *Rusage) Pid_t __asm__ ("wait4")
 // In the child, this function must not acquire any locks, because
 // they might have been locked at the time of the fork.  This means
 // no rescheduling, no malloc calls, and no new stack segments.
-func forkAndExecInChild(argv0 *byte, argv []*byte, envv []*byte, traceme bool, dir *byte, fd []int, pipe int) (pid int, err int) {
+func forkAndExecInChild(argv0 *byte, argv, envv []*byte, dir *byte, attr *ProcAttr, pipe int) (pid int, err int) {
        // Declare all variables at top in case any
        // declarations require heap allocation (e.g., err1).
        var r1, r2, err1 uintptr
        var nextfd int
        var i int
 
+       // guard against side effects of shuffling fds below.
+       fd := append([]int(nil), attr.Files...)
+
        darwin := OS == "darwin"
 
        // About to call fork.
@@ -48,16 +52,22 @@ func forkAndExecInChild(argv0 *byte, argv []*byte, envv []*byte, traceme bool, d
        // Fork succeeded, now in child.
 
        // Enable tracing if requested.
-       if traceme {
+       if attr.Ptrace {
                if libc_ptrace(_PTRACE_TRACEME, 0, 0, nil) < 0 {
                        goto childerror
                }
        }
 
+       // Session ID
+       if attr.Setsid {
+               if libc_setsid() == Pid_t(-1) {
+                       goto childerror
+               }
+       }
+
        // Chdir
        if dir != nil {
-               r := libc_chdir(dir)
-               if r < 0 {
+               if libc_chdir(dir) < 0 {
                        goto childerror
                }
        }
@@ -138,22 +148,42 @@ childerror:
        panic("unreached")
 }
 
-func forkExec(argv0 string, argv []string, envv []string, traceme bool, dir string, fd []int) (pid int, err int) {
+
+type ProcAttr struct {
+       Setsid bool     // Create session.
+       Ptrace bool     // Enable tracing.
+       Dir    string   // Current working directory.
+       Env    []string // Environment.
+       Files  []int    // File descriptors.
+}
+
+var zeroAttributes ProcAttr
+
+func forkExec(argv0 string, argv []string, attr *ProcAttr) (pid int, err int) {
        var p [2]int
        var r1 int
        var err1 uintptr
        var wstatus WaitStatus
 
+       if attr == nil {
+               attr = &zeroAttributes
+       }
+
        p[0] = -1
        p[1] = -1
 
        // Convert args to C form.
        argv0p := StringBytePtr(argv0)
        argvp := StringArrayPtr(argv)
-       envvp := StringArrayPtr(envv)
-       var dirp *byte
-       if len(dir) > 0 {
-               dirp = StringBytePtr(dir)
+       envvp := StringArrayPtr(attr.Env)
+
+       if OS == "freebsd" && len(argv[0]) > len(argv0) {
+               argvp[0] = argv0p
+       }
+
+       var dir *byte
+       if attr.Dir != "" {
+               dir = StringBytePtr(attr.Dir)
        }
 
        // Acquire the fork lock so that no other threads
@@ -173,7 +203,7 @@ func forkExec(argv0 string, argv []string, envv []string, traceme bool, dir stri
        }
 
        // Kick off child.
-       pid, err = forkAndExecInChild(argv0p, argvp, envvp, traceme, dirp, fd, p[1])
+       pid, err = forkAndExecInChild(argv0p, argvp, envvp, dir, attr, p[1])
        if err != 0 {
        error:
                if p[0] >= 0 {
@@ -216,13 +246,14 @@ func forkExec(argv0 string, argv []string, envv []string, traceme bool, dir stri
 }
 
 // Combination of fork and exec, careful to be thread safe.
-func ForkExec(argv0 string, argv []string, envv []string, dir string, fd []int) (pid int, err int) {
-       return forkExec(argv0, argv, envv, false, dir, fd)
+func ForkExec(argv0 string, argv []string, attr *ProcAttr) (pid int, err int) {
+       return forkExec(argv0, argv, attr)
 }
 
-// PtraceForkExec is like ForkExec, but starts the child in a traced state.
-func PtraceForkExec(argv0 string, argv []string, envv []string, dir string, fd []int) (pid int, err int) {
-       return forkExec(argv0, argv, envv, true, dir, fd)
+// StartProcess wraps ForkExec for package os.
+func StartProcess(argv0 string, argv []string, attr *ProcAttr) (pid, handle int, err int) {
+       pid, err = forkExec(argv0, argv, attr)
+       return pid, 0, err
 }
 
 // Ordinary exec.
@@ -233,12 +264,6 @@ func Exec(argv0 string, argv []string, envv []string) (err int) {
        return GetErrno()
 }
 
-// StartProcess wraps ForkExec for package os.
-func StartProcess(argv0 string, argv []string, envv []string, dir string, fd []int) (pid, handle int, err int) {
-       pid, err = forkExec(argv0, argv, envv, false, dir, fd)
-       return pid, 0, err
-}
-
 func Wait4(pid int, wstatus *WaitStatus, options int, rusage *Rusage) (wpid int, errno int) {
        var status int
        r := libc_wait4(Pid_t(pid), &status, options, rusage)
index 93db3462e8ac09ffb13fc901e95889b4e754bf09..20ae0a0bd25b93be236851ae31b02d973a40415d 100755 (executable)
@@ -276,6 +276,13 @@ if $havex; then
 fi
 
 # They all compile; now generate the code to call them.
+
+localname() {
+       # The package main has been renamed to __main__ when imported.
+       # Adjust its uses.
+       echo $1 | sed 's/^main\./__main__./'
+}
+
 {
        # test functions are named TestFoo
        # the grep -v eliminates methods and other special names
@@ -288,6 +295,9 @@ fi
                echo 'gotest: warning: no tests matching '$pattern in _gotest_.o $xofile 1>&2
                exit 2
        fi
+       # benchmarks are named BenchmarkFoo.
+       pattern='Benchmark([^a-z].*)?'
+       benchmarks=$($NM -p -v _gotest_.o $xofile | egrep ' T .*\.'$pattern'$' | grep -v '\..*\..*\.' | sed 's/.* //' | sed 's/.*\.\(.*\.\)/\1/')
 
        # package spec
        echo 'package main'
@@ -299,23 +309,48 @@ fi
        if $havex; then
                echo 'import "./_xtest_"'
        fi
-       if [ $package != "testing" ]; then
-               echo 'import "testing"'
-               echo 'import __regexp__ "regexp"' # rename in case tested package is called regexp
-       fi
+       echo 'import "testing"'
+       echo 'import __os__     "os"' # rename in case tested package is called os
+       echo 'import __regexp__ "regexp"' # rename in case tested package is called regexp
        # test array
        echo
        echo 'var tests = []testing.InternalTest {'
        for i in $tests
        do
-               echo '  { "'$i'", '$i' },'
+               j=$(localname $i)
+               echo '  {"'$i'", '$j'},'
        done
        echo '}'
-       # body
-       echo
-       echo 'func main() {'
-       echo '  testing.Main(__regexp__.MatchString, tests)'
+       # benchmark array
+       # The comment makes the multiline declaration
+       # gofmt-safe even when there are no benchmarks.
+       echo 'var benchmarks = []testing.InternalBenchmark{ //'
+       for i in $benchmarks
+       do
+               j=$(localname $i)
+               echo '  {"'$i'", '$j'},'
+       done
        echo '}'
+       # body
+       echo \
+'
+var matchPat string
+var matchRe *__regexp__.Regexp
+
+func matchString(pat, str string) (result bool, err __os__.Error) {
+       if matchRe == nil || matchPat != pat {
+               matchPat = pat
+               matchRe, err = __regexp__.Compile(matchPat)
+               if err != nil {
+                       return
+               }
+       }
+       return matchRe.MatchString(str), nil
+}
+
+func main() {
+       testing.Main(matchString, tests, benchmarks)
+}'
 }>_testmain.go
 
 case "x$dejagnu" in