From: Ian Lance Taylor Date: Fri, 20 May 2011 00:18:15 +0000 (+0000) Subject: Update to current version of Go library. X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=9ff56c9570642711d5b7ab29920ecf5dbff14a27;p=gcc.git Update to current version of Go library. From-SVN: r173931 --- diff --git a/gcc/go/gofrontend/types.cc b/gcc/go/gofrontend/types.cc index e33b349b3c1..bf1921610c1 100644 --- a/gcc/go/gofrontend/types.cc +++ b/gcc/go/gofrontend/types.cc @@ -1432,7 +1432,7 @@ Type::methods_constructor(Gogo* gogo, Type* methods_type, p != smethods.end(); ++p) vals->push_back(this->method_constructor(gogo, method_type, p->first, - p->second)); + p->second, only_value_methods)); return Expression::make_slice_composite_literal(methods_type, vals, bloc); } @@ -1444,7 +1444,8 @@ Type::methods_constructor(Gogo* gogo, Type* methods_type, Expression* Type::method_constructor(Gogo*, Type* method_type, const std::string& method_name, - const Method* m) const + const Method* m, + bool only_value_methods) const { source_location bloc = BUILTINS_LOCATION; @@ -1487,6 +1488,25 @@ Type::method_constructor(Gogo*, Type* method_type, ++p; go_assert(p->field_name() == "typ"); + if (!only_value_methods && m->is_value_method()) + { + // This is a value method on a pointer type. Change the type of + // the method to use a pointer receiver. The implementation + // always uses a pointer receiver anyhow. + Type* rtype = mtype->receiver()->type(); + Type* prtype = Type::make_pointer_type(rtype); + Typed_identifier* receiver = + new Typed_identifier(mtype->receiver()->name(), prtype, + mtype->receiver()->location()); + mtype = Type::make_function_type(receiver, + (mtype->parameters() == NULL + ? NULL + : mtype->parameters()->copy()), + (mtype->results() == NULL + ? NULL + : mtype->results()->copy()), + mtype->location()); + } vals->push_back(Expression::make_type_descriptor(mtype, bloc)); ++p; @@ -2779,14 +2799,7 @@ Function_type::type_descriptor_params(Type* params_type, + (receiver != NULL ? 1 : 0)); if (receiver != NULL) - { - Type* rtype = receiver->type(); - // The receiver is always passed as a pointer. FIXME: Is this - // right? Should that fact affect the type descriptor? - if (rtype->points_to() == NULL) - rtype = Type::make_pointer_type(rtype); - vals->push_back(Expression::make_type_descriptor(rtype, bloc)); - } + vals->push_back(Expression::make_type_descriptor(receiver->type(), bloc)); if (params != NULL) { @@ -4822,9 +4835,10 @@ Array_type::make_array_type_descriptor_type() Type* uintptr_type = Type::lookup_integer_type("uintptr"); Struct_type* sf = - Type::make_builtin_struct_type(3, + Type::make_builtin_struct_type(4, "", tdt, "elem", ptdt, + "slice", ptdt, "len", uintptr_type); ret = Type::make_builtin_named_type("ArrayType", sf); @@ -4890,6 +4904,11 @@ Array_type::array_type_descriptor(Gogo* gogo, Named_type* name) go_assert(p->field_name() == "elem"); vals->push_back(Expression::make_type_descriptor(this->element_type_, bloc)); + ++p; + go_assert(p->field_name() == "slice"); + Type* slice_type = Type::make_array_type(this->element_type_, NULL); + vals->push_back(Expression::make_type_descriptor(slice_type, bloc)); + ++p; go_assert(p->field_name() == "len"); vals->push_back(Expression::make_cast(p->type(), this->length_, bloc)); @@ -5375,8 +5394,9 @@ Channel_type::do_make_expression_tree(Translate_context* context, Gogo* gogo = context->gogo(); tree channel_type = type_to_tree(this->get_backend(gogo)); - tree element_tree = type_to_tree(this->element_type_->get_backend(gogo)); - tree element_size_tree = size_in_bytes(element_tree); + Type* ptdt = Type::make_type_descriptor_ptr_type(); + tree element_type_descriptor = + this->element_type_->type_descriptor_pointer(gogo); tree bad_index = NULL_TREE; @@ -5402,8 +5422,8 @@ Channel_type::do_make_expression_tree(Translate_context* context, "__go_new_channel", 2, channel_type, - sizetype, - element_size_tree, + type_to_tree(ptdt->get_backend(gogo)), + element_type_descriptor, sizetype, expr_tree); if (ret == error_mark_node) @@ -6242,7 +6262,16 @@ Interface_type::do_reflection(Gogo* gogo, std::string* ret) const if (p != this->methods_->begin()) ret->append(";"); ret->push_back(' '); - ret->append(Gogo::unpack_hidden_name(p->name())); + if (!Gogo::is_hidden_name(p->name())) + ret->append(p->name()); + else + { + // This matches what the gc compiler does. + std::string prefix = Gogo::hidden_name_prefix(p->name()); + ret->append(prefix.substr(prefix.find('.') + 1)); + ret->push_back('.'); + ret->append(Gogo::unpack_hidden_name(p->name())); + } std::string sub = p->type()->reflection(gogo); go_assert(sub.compare(0, 4, "func") == 0); sub = sub.substr(4); diff --git a/gcc/go/gofrontend/types.h b/gcc/go/gofrontend/types.h index 913266b784c..3ada1b1f0f8 100644 --- a/gcc/go/gofrontend/types.h +++ b/gcc/go/gofrontend/types.h @@ -1044,7 +1044,7 @@ class Type // Build a composite literal for one method. Expression* method_constructor(Gogo*, Type* method_type, const std::string& name, - const Method*) const; + const Method*, bool only_value_methods) const; static tree build_receive_return_type(tree type); diff --git a/gcc/testsuite/go.test/test/convert.go b/gcc/testsuite/go.test/test/convert.go index e7361aa87f8..0a75663d064 100644 --- a/gcc/testsuite/go.test/test/convert.go +++ b/gcc/testsuite/go.test/test/convert.go @@ -8,7 +8,7 @@ package main import "reflect" -func typeof(x interface{}) string { return reflect.Typeof(x).String() } +func typeof(x interface{}) string { return reflect.TypeOf(x).String() } func f() int { return 0 } diff --git a/gcc/testsuite/go.test/test/fixedbugs/bug177.go b/gcc/testsuite/go.test/test/fixedbugs/bug177.go index 84ff59d2f5d..a120ad0abf2 100644 --- a/gcc/testsuite/go.test/test/fixedbugs/bug177.go +++ b/gcc/testsuite/go.test/test/fixedbugs/bug177.go @@ -5,23 +5,26 @@ // license that can be found in the LICENSE file. package main + import "reflect" -type S1 struct { i int } -type S2 struct { S1 } + +type S1 struct{ i int } +type S2 struct{ S1 } + func main() { - typ := reflect.Typeof(S2{}).(*reflect.StructType); - f := typ.Field(0); + typ := reflect.TypeOf(S2{}) + f := typ.Field(0) if f.Name != "S1" || f.Anonymous != true { - println("BUG: ", f.Name, f.Anonymous); - return; + println("BUG: ", f.Name, f.Anonymous) + return } - f, ok := typ.FieldByName("S1"); + f, ok := typ.FieldByName("S1") if !ok { - println("BUG: missing S1"); - return; + println("BUG: missing S1") + return } if !f.Anonymous { - println("BUG: S1 is not anonymous"); - return; + println("BUG: S1 is not anonymous") + return } } diff --git a/gcc/testsuite/go.test/test/fixedbugs/bug248.dir/bug2.go b/gcc/testsuite/go.test/test/fixedbugs/bug248.dir/bug2.go index 4ea187a4b12..b6c816a5cef 100644 --- a/gcc/testsuite/go.test/test/fixedbugs/bug248.dir/bug2.go +++ b/gcc/testsuite/go.test/test/fixedbugs/bug248.dir/bug2.go @@ -38,11 +38,11 @@ func main() { // meaning that reflect data for v0, v1 didn't get confused. // path is full (rooted) path name. check suffix for gc, prefix for gccgo - if s := reflect.Typeof(v0).PkgPath(); !strings.HasSuffix(s, "/bug0") && !strings.HasPrefix(s, "bug0") { + if s := reflect.TypeOf(v0).PkgPath(); !strings.HasSuffix(s, "/bug0") && !strings.HasPrefix(s, "bug0") { println("bad v0 path", len(s), s) panic("fail") } - if s := reflect.Typeof(v1).PkgPath(); !strings.HasSuffix(s, "/bug1") && !strings.HasPrefix(s, "bug1") { + if s := reflect.TypeOf(v1).PkgPath(); !strings.HasSuffix(s, "/bug1") && !strings.HasPrefix(s, "bug1") { println("bad v1 path", s) panic("fail") } diff --git a/gcc/testsuite/go.test/test/interface/fake.go b/gcc/testsuite/go.test/test/interface/fake.go index 5cf3be052ce..bdc5b9072c3 100644 --- a/gcc/testsuite/go.test/test/interface/fake.go +++ b/gcc/testsuite/go.test/test/interface/fake.go @@ -46,34 +46,34 @@ func main() { x.t = add("abc", "def") x.u = 1 x.v = 2 - x.w = 1<<28 - x.x = 2<<28 + x.w = 1 << 28 + x.x = 2 << 28 x.y = 0x12345678 x.z = x.y // check mem and string - v := reflect.NewValue(x) - i := v.(*reflect.StructValue).Field(0) - j := v.(*reflect.StructValue).Field(1) + v := reflect.ValueOf(x) + i := v.Field(0) + j := v.Field(1) assert(i.Interface() == j.Interface()) - s := v.(*reflect.StructValue).Field(2) - t := v.(*reflect.StructValue).Field(3) + s := v.Field(2) + t := v.Field(3) assert(s.Interface() == t.Interface()) // make sure different values are different. // make sure whole word is being compared, // not just a single byte. - i = v.(*reflect.StructValue).Field(4) - j = v.(*reflect.StructValue).Field(5) + i = v.Field(4) + j = v.Field(5) assert(i.Interface() != j.Interface()) - i = v.(*reflect.StructValue).Field(6) - j = v.(*reflect.StructValue).Field(7) + i = v.Field(6) + j = v.Field(7) assert(i.Interface() != j.Interface()) - i = v.(*reflect.StructValue).Field(8) - j = v.(*reflect.StructValue).Field(9) + i = v.Field(8) + j = v.Field(9) assert(i.Interface() == j.Interface()) } diff --git a/gcc/testsuite/go.test/test/ken/cplx3.go b/gcc/testsuite/go.test/test/ken/cplx3.go index 83acc15ff7c..fa6ff1d52ee 100644 --- a/gcc/testsuite/go.test/test/ken/cplx3.go +++ b/gcc/testsuite/go.test/test/ken/cplx3.go @@ -25,9 +25,9 @@ func main() { println(c) var a interface{} - switch c := reflect.NewValue(a).(type) { - case *reflect.ComplexValue: - v := c.Get() + switch c := reflect.ValueOf(a); c.Kind() { + case reflect.Complex64, reflect.Complex128: + v := c.Complex() _, _ = complex128(v), true } } diff --git a/libgo/MERGE b/libgo/MERGE index 237a0c2565e..9cee703c403 100644 --- a/libgo/MERGE +++ b/libgo/MERGE @@ -1,4 +1,4 @@ -f618e5e0991d +aea0ba6e5935 The first line of this file holds the Mercurial revision number of the last merge done from the master library sources. diff --git a/libgo/Makefile.am b/libgo/Makefile.am index 05f1c991638..ae6848f0932 100644 --- a/libgo/Makefile.am +++ b/libgo/Makefile.am @@ -248,7 +248,8 @@ toolexeclibgogo_DATA = \ go/printer.gox \ go/scanner.gox \ go/token.gox \ - go/typechecker.gox + go/typechecker.gox \ + go/types.gox toolexeclibgohashdir = $(toolexeclibgodir)/hash @@ -262,14 +263,19 @@ toolexeclibgohttpdir = $(toolexeclibgodir)/http toolexeclibgohttp_DATA = \ http/cgi.gox \ + http/fcgi.gox \ http/httptest.gox \ - http/pprof.gox + http/pprof.gox \ + http/spdy.gox toolexeclibgoimagedir = $(toolexeclibgodir)/image toolexeclibgoimage_DATA = \ + image/gif.gox \ image/jpeg.gox \ - image/png.gox + image/png.gox \ + image/tiff.gox \ + image/ycbcr.gox toolexeclibgoindexdir = $(toolexeclibgodir)/index @@ -303,6 +309,7 @@ endif toolexeclibgoos_DATA = \ $(os_inotify_gox) \ + os/user.gox \ os/signal.gox toolexeclibgopathdir = $(toolexeclibgodir)/path @@ -404,6 +411,7 @@ runtime_files = \ runtime/go-send-nb-big.c \ runtime/go-send-nb-small.c \ runtime/go-send-small.c \ + runtime/go-setenv.c \ runtime/go-signal.c \ runtime/go-strcmp.c \ runtime/go-string-to-byte-array.c \ @@ -560,6 +568,7 @@ go_http_files = \ go/http/persist.go \ go/http/request.go \ go/http/response.go \ + go/http/reverseproxy.go \ go/http/server.go \ go/http/status.go \ go/http/transfer.go \ @@ -656,8 +665,17 @@ go_net_newpollserver_file = go/net/newpollserver.go endif # !LIBGO_IS_LINUX endif # !LIBGO_IS_RTEMS +if LIBGO_IS_LINUX +go_net_cgo_file = go/net/cgo_linux.go +go_net_sock_file = go/net/sock_linux.go +else +go_net_cgo_file = go/net/cgo_bsd.go +go_net_sock_file = go/net/sock_bsd.go +endif + go_net_files = \ - go/net/cgo_stub.go \ + go/net/cgo_unix.go \ + $(go_net_cgo_file) \ go/net/dial.go \ go/net/dnsclient.go \ go/net/dnsconfig.go \ @@ -676,6 +694,7 @@ go_net_files = \ go/net/pipe.go \ go/net/port.go \ go/net/sock.go \ + $(go_net_sock_file) \ go/net/tcpsock.go \ go/net/udpsock.go \ go/net/unixsock.go @@ -1002,7 +1021,6 @@ go_crypto_subtle_files = \ go/crypto/subtle/constant_time.go go_crypto_tls_files = \ go/crypto/tls/alert.go \ - go/crypto/tls/ca_set.go \ go/crypto/tls/cipher_suites.go \ go/crypto/tls/common.go \ go/crypto/tls/conn.go \ @@ -1015,6 +1033,8 @@ go_crypto_tls_files = \ go_crypto_twofish_files = \ go/crypto/twofish/twofish.go go_crypto_x509_files = \ + go/crypto/x509/cert_pool.go \ + go/crypto/x509/verify.go \ go/crypto/x509/x509.go go_crypto_xtea_files = \ go/crypto/xtea/block.go \ @@ -1130,6 +1150,12 @@ go_go_typechecker_files = \ go/go/typechecker/type.go \ go/go/typechecker/typechecker.go \ go/go/typechecker/universe.go +go_go_types_files = \ + go/go/types/const.go \ + go/go/types/exportdata.go \ + go/go/types/gcimporter.go \ + go/go/types/types.go \ + go/go/types/universe.go go_hash_adler32_files = \ go/hash/adler32/adler32.go @@ -1143,21 +1169,39 @@ go_hash_fnv_files = \ go_http_cgi_files = \ go/http/cgi/child.go \ go/http/cgi/host.go +go_http_fcgi_files = \ + go/http/fcgi/child.go \ + go/http/fcgi/fcgi.go go_http_httptest_files = \ go/http/httptest/recorder.go \ go/http/httptest/server.go go_http_pprof_files = \ go/http/pprof/pprof.go +go_http_spdy_files = \ + go/http/spdy/protocol.go + +go_image_gif_files = \ + go/image/gif/reader.go go_image_jpeg_files = \ + go/image/jpeg/fdct.go \ go/image/jpeg/huffman.go \ go/image/jpeg/idct.go \ - go/image/jpeg/reader.go + go/image/jpeg/reader.go \ + go/image/jpeg/writer.go go_image_png_files = \ go/image/png/reader.go \ go/image/png/writer.go +go_image_tiff_files = \ + go/image/tiff/buffer.go \ + go/image/tiff/consts.go \ + go/image/tiff/reader.go + +go_image_ycbcr_files = \ + go/image/ycbcr/ycbcr.go + go_index_suffixarray_files = \ go/index/suffixarray/qsufsort.go \ go/index/suffixarray/suffixarray.go @@ -1167,6 +1211,7 @@ go_io_ioutil_files = \ go/io/ioutil/tempfile.go go_mime_multipart_files = \ + go/mime/multipart/formdata.go \ go/mime/multipart/multipart.go go_net_dict_files = \ @@ -1182,6 +1227,10 @@ go_net_textproto_files = \ go_os_inotify_files = \ go/os/inotify/inotify_linux.go +go_os_user_files = \ + go/os/user/user.go \ + go/os/user/lookup_unix.go + go_os_signal_files = \ go/os/signal/signal.go \ unix.go @@ -1485,21 +1534,28 @@ libgo_go_objs = \ go/scanner.lo \ go/token.lo \ go/typechecker.lo \ + go/types.lo \ hash/adler32.lo \ hash/crc32.lo \ hash/crc64.lo \ hash/fnv.lo \ http/cgi.lo \ + http/fcgi.lo \ http/httptest.lo \ http/pprof.lo \ + http/spdy.lo \ + image/gif.lo \ image/jpeg.lo \ image/png.lo \ + image/tiff.lo \ + image/ycbcr.lo \ index/suffixarray.lo \ io/ioutil.lo \ mime/multipart.lo \ net/dict.lo \ net/textproto.lo \ $(os_lib_inotify_lo) \ + os/user.lo \ os/signal.lo \ path/filepath.lo \ rpc/jsonrpc.lo \ @@ -1711,11 +1767,12 @@ html/check: $(CHECK_DEPS) @$(CHECK) .PHONY: html/check -http/http.lo: $(go_http_files) bufio.gox bytes.gox container/vector.gox \ - crypto/rand.gox crypto/tls.gox encoding/base64.gox fmt.gox \ - io.gox io/ioutil.gox log.gox mime.gox mime/multipart.gox \ - net.gox net/textproto.gox os.gox path.gox path/filepath.gox \ - sort.gox strconv.gox strings.gox sync.gox time.gox utf8.gox +http/http.lo: $(go_http_files) bufio.gox bytes.gox compress/gzip.gox \ + container/vector.gox crypto/rand.gox crypto/tls.gox \ + encoding/base64.gox fmt.gox io.gox io/ioutil.gox log.gox \ + mime.gox mime/multipart.gox net.gox net/textproto.gox os.gox \ + path.gox path/filepath.gox sort.gox strconv.gox strings.gox \ + sync.gox time.gox utf8.gox $(BUILDPACKAGE) http/check: $(CHECK_DEPS) @$(CHECK) @@ -1755,7 +1812,7 @@ math/check: $(CHECK_DEPS) @$(CHECK) .PHONY: math/check -mime/mime.lo: $(go_mime_files) bufio.gox bytes.gox os.gox strings.gox \ +mime/mime.lo: $(go_mime_files) bufio.gox bytes.gox fmt.gox os.gox strings.gox \ sync.gox unicode.gox $(BUILDPACKAGE) mime/check: $(CHECK_DEPS) @@ -1763,8 +1820,8 @@ mime/check: $(CHECK_DEPS) .PHONY: mime/check net/net.lo: $(go_net_files) bytes.gox fmt.gox io.gox os.gox rand.gox \ - reflect.gox strconv.gox strings.gox sync.gox syscall.gox \ - time.gox + reflect.gox sort.gox strconv.gox strings.gox sync.gox \ + syscall.gox time.gox $(BUILDPACKAGE) net/check: $(CHECK_DEPS) @$(CHECK_ON_REQUEST) @@ -1945,8 +2002,8 @@ xml/check: $(CHECK_DEPS) @$(CHECK) .PHONY: xml/check -archive/tar.lo: $(go_archive_tar_files) bytes.gox io.gox os.gox strconv.gox \ - strings.gox +archive/tar.lo: $(go_archive_tar_files) bytes.gox io.gox io/ioutil.gox os.gox \ + strconv.gox strings.gox $(BUILDPACKAGE) archive/tar/check: $(CHECK_DEPS) @$(MKDIR_P) archive/tar @@ -2148,8 +2205,7 @@ crypto/ripemd160/check: $(CHECK_DEPS) .PHONY: crypto/ripemd160/check crypto/rsa.lo: $(go_crypto_rsa_files) big.gox crypto.gox crypto/sha1.gox \ - crypto/subtle.gox encoding/hex.gox hash.gox io.gox os.gox \ - sync.gox + crypto/subtle.gox encoding/hex.gox hash.gox io.gox os.gox $(BUILDPACKAGE) crypto/rsa/check: $(CHECK_DEPS) @$(MKDIR_P) crypto/rsa @@ -2184,13 +2240,13 @@ crypto/subtle/check: $(CHECK_DEPS) @$(CHECK) .PHONY: crypto/subtle/check -crypto/tls.lo: $(go_crypto_tls_files) big.gox bufio.gox bytes.gox \ - container/list.gox crypto.gox crypto/aes.gox crypto/cipher.gox \ - crypto/elliptic.gox crypto/hmac.gox crypto/md5.gox \ - crypto/rc4.gox crypto/rand.gox crypto/rsa.gox crypto/sha1.gox \ - crypto/subtle.gox crypto/rsa.gox crypto/sha1.gox \ - crypto/x509.gox encoding/pem.gox fmt.gox hash.gox io.gox \ - io/ioutil.gox net.gox os.gox strings.gox sync.gox time.gox +crypto/tls.lo: $(go_crypto_tls_files) big.gox bytes.gox crypto.gox \ + crypto/aes.gox crypto/cipher.gox crypto/elliptic.gox \ + crypto/hmac.gox crypto/md5.gox crypto/rand.gox crypto/rc4.gox \ + crypto/rsa.gox crypto/sha1.gox crypto/subtle.gox \ + crypto/x509.gox encoding/pem.gox hash.gox io.gox \ + io/ioutil.gox net.gox os.gox strconv.gox strings.gox sync.gox \ + time.gox $(BUILDPACKAGE) crypto/tls/check: $(CHECK_DEPS) @$(MKDIR_P) crypto/tls @@ -2204,9 +2260,10 @@ crypto/twofish/check: $(CHECK_DEPS) @$(CHECK) .PHONY: crypto/twofish/check -crypto/x509.lo: $(go_crypto_x509_files) asn1.gox big.gox container/vector.gox \ - crypto.gox crypto/rsa.gox crypto/sha1.gox hash.gox os.gox \ - strings.gox time.gox +crypto/x509.lo: $(go_crypto_x509_files) asn1.gox big.gox bytes.gox \ + container/vector.gox crypto.gox crypto/rsa.gox \ + crypto/sha1.gox encoding/pem.gox hash.gox os.gox strings.gox \ + time.gox $(BUILDPACKAGE) crypto/x509/check: $(CHECK_DEPS) @$(MKDIR_P) crypto/x509 @@ -2220,9 +2277,8 @@ crypto/xtea/check: $(CHECK_DEPS) @$(CHECK) .PHONY: crypto/xtea/check -crypto/openpgp/armor.lo: $(go_crypto_openpgp_armor_files) bytes.gox \ - crypto/openpgp/error.gox encoding/base64.gox \ - encoding/line.gox io.gox os.gox +crypto/openpgp/armor.lo: $(go_crypto_openpgp_armor_files) bufio.gox bytes.gox \ + crypto/openpgp/error.gox encoding/base64.gox io.gox os.gox $(BUILDPACKAGE) crypto/openpgp/armor/check: $(CHECK_DEPS) @$(MKDIR_P) crypto/openpgp/armor @@ -2374,7 +2430,7 @@ exp/datafmt/check: $(CHECK_DEPS) @$(CHECK) .PHONY: exp/datafmt/check -exp/draw.lo: $(go_exp_draw_files) image.gox os.gox +exp/draw.lo: $(go_exp_draw_files) image.gox image/ycbcr.gox os.gox $(BUILDPACKAGE) exp/draw/check: $(CHECK_DEPS) @$(MKDIR_P) exp/draw @@ -2448,6 +2504,15 @@ go/typechecker/check: $(CHECK_DEPS) @$(CHECK) .PHONY: go/typechecker/check +go/types.lo: $(go_go_types_files) big.gox bufio.gox fmt.gox go/ast.gox \ + go/token.gox io.gox os.gox path/filepath.gox runtime.gox \ + scanner.gox strconv.gox strings.gox + $(BUILDPACKAGE) +go/types/check: $(CHECK_DEPS) + @$(MKDIR_P) go/types + @$(CHECK) +.PHONY: go/types/check + hash/adler32.lo: $(go_hash_adler32_files) hash.gox os.gox $(BUILDPACKAGE) hash/adler32/check: $(CHECK_DEPS) @@ -2476,15 +2541,25 @@ hash/fnv/check: $(CHECK_DEPS) @$(CHECK) .PHONY: hash/fnv/check -http/cgi.lo: $(go_http_cgi_files) bufio.gox bytes.gox encoding/line.gox \ - exec.gox fmt.gox http.gox io.gox io/ioutil.gox log.gox \ - os.gox path/filepath.gox regexp.gox strconv.gox strings.gox +http/cgi.lo: $(go_http_cgi_files) bufio.gox bytes.gox crypto/tls.gox \ + exec.gox fmt.gox http.gox net.gox io.gox io/ioutil.gox \ + log.gox os.gox path/filepath.gox regexp.gox strconv.gox \ + strings.gox $(BUILDPACKAGE) http/cgi/check: $(CHECK_DEPS) @$(MKDIR_P) http/cgi @$(CHECK) .PHONY: http/cgi/check +http/fcgi.lo: $(go_http_fcgi_files) bufio.gox bytes.gox encoding/binary.gox \ + fmt.gox http.gox http/cgi.gox io.gox net.gox os.gox sync.gox \ + time.gox + $(BUILDPACKAGE) +http/fcgi/check: $(CHECK_DEPS) + @$(MKDIR_P) http/fcgi + @$(CHECK) +.PHONY: http/fcgi/check + http/httptest.lo: $(go_http_httptest_files) bytes.gox crypto/rand.gox \ crypto/tls.gox fmt.gox http.gox net.gox os.gox time.gox $(BUILDPACKAGE) @@ -2493,15 +2568,33 @@ http/httptest/check: $(CHECK_DEPS) @$(CHECK) .PHONY: http/httptest/check -http/pprof.lo: $(go_http_pprof_files) bufio.gox fmt.gox http.gox os.gox \ - runtime.gox runtime/pprof.gox strconv.gox strings.gox +http/pprof.lo: $(go_http_pprof_files) bufio.gox bytes.gox fmt.gox http.gox \ + os.gox runtime.gox runtime/pprof.gox strconv.gox strings.gox $(BUILDPACKAGE) http/pprof/check: $(CHECK_DEPS) @$(MKDIR_P) http/pprof @$(CHECK) .PHONY: http/pprof/check -image/jpeg.lo: $(go_image_jpeg_files) bufio.gox image.gox io.gox os.gox +http/spdy.lo: $(go_http_spdy_files) bytes.gox compress/zlib.gox \ + encoding/binary.gox http.gox io.gox os.gox strconv.gox \ + strings.gox sync.gox + $(BUILDPACKAGE) +http/spdy/check: $(CHECK_DEPS) + @$(MKDIR_P) http/spdy + @$(CHECK) +.PHONY: http/spdy/check + +image/gif.lo: $(go_image_gif_files) bufio.gox compress/lzw.gox fmt.gox \ + image.gox io.gox os.gox + $(BUILDPACKAGE) +image/gif/check: $(CHECK_DEPS) + @$(MKDIR_P) image/gif + @$(CHECK) +.PHONY: image/gif/check + +image/jpeg.lo: $(go_image_jpeg_files) bufio.gox image.gox image/ycbcr.gox \ + io.gox os.gox $(BUILDPACKAGE) image/jpeg/check: $(CHECK_DEPS) @$(MKDIR_P) image/jpeg @@ -2516,6 +2609,21 @@ image/png/check: $(CHECK_DEPS) @$(CHECK) .PHONY: image/png/check +image/tiff.lo: $(go_image_tiff_files) compress/lzw.gox compress/zlib.gox \ + encoding/binary.gox image.gox io.gox io/ioutil.gox os.gox + $(BUILDPACKAGE) +image/tiff/check: $(CHECK_DEPS) + @$(MKDIR_P) image/tiff + @$(CHECK) +.PHONY: image/tiff/check + +image/ycbcr.lo: $(go_image_ycbcr_files) image.gox + $(BUILDPACKAGE) +image/ycbcr/check: $(CHECK_DEPS) + @$(MKDIR_P) image/ycbcr + @$(CHECK) +.PHONY: image/ycbcr/check + index/suffixarray.lo: $(go_index_suffixarray_files) bytes.gox regexp.gox \ sort.gox $(BUILDPACKAGE) @@ -2532,8 +2640,9 @@ io/ioutil/check: $(CHECK_DEPS) @$(CHECK) .PHONY: io/ioutil/check -mime/multipart.lo: $(go_mime_multipart_files) bufio.gox bytes.gox io.gox \ - mime.gox net/textproto.gox os.gox regexp.gox strings.gox +mime/multipart.lo: $(go_mime_multipart_files) bufio.gox bytes.gox fmt.gox \ + io.gox io/ioutil.gox mime.gox net/textproto.gox os.gox \ + regexp.gox $(BUILDPACKAGE) mime/multipart/check: $(CHECK_DEPS) @$(MKDIR_P) mime/multipart @@ -2560,6 +2669,14 @@ os/inotify/check: $(CHECK_DEPS) @$(CHECK) .PHONY: os/inotify/check +os/user.lo: $(go_os_user_files) fmt.gox os.gox runtime.gox strconv.gox \ + strings.gox syscall.gox + $(BUILDPACKAGE) +os/user/check: $(CHECK_DEPS) + @$(MKDIR_P) os/user + @$(CHECK) +.PHONY: os/user/check + os/signal.lo: $(go_os_signal_files) runtime.gox strconv.gox $(BUILDPACKAGE) os/signal/check: $(CHECK_DEPS) @@ -2886,6 +3003,8 @@ go/token.gox: go/token.lo $(BUILDGOX) go/typechecker.gox: go/typechecker.lo $(BUILDGOX) +go/types.gox: go/types.lo + $(BUILDGOX) hash/adler32.gox: hash/adler32.lo $(BUILDGOX) @@ -2898,15 +3017,25 @@ hash/fnv.gox: hash/fnv.lo http/cgi.gox: http/cgi.lo $(BUILDGOX) +http/fcgi.gox: http/fcgi.lo + $(BUILDGOX) http/httptest.gox: http/httptest.lo $(BUILDGOX) http/pprof.gox: http/pprof.lo $(BUILDGOX) +http/spdy.gox: http/spdy.lo + $(BUILDGOX) +image/gif.gox: image/gif.lo + $(BUILDGOX) image/jpeg.gox: image/jpeg.lo $(BUILDGOX) image/png.gox: image/png.lo $(BUILDGOX) +image/tiff.gox: image/tiff.lo + $(BUILDGOX) +image/ycbcr.gox: image/ycbcr.lo + $(BUILDGOX) index/suffixarray.gox: index/suffixarray.lo $(BUILDGOX) @@ -2924,6 +3053,8 @@ net/textproto.gox: net/textproto.lo os/inotify.gox: os/inotify.lo $(BUILDGOX) +os/user.gox: os/user.lo + $(BUILDGOX) os/signal.gox: os/signal.lo $(BUILDGOX) @@ -3054,22 +3185,30 @@ TEST_PACKAGES = \ exp/datafmt/check \ exp/draw/check \ exp/eval/check \ + go/ast/check \ go/parser/check \ go/printer/check \ go/scanner/check \ go/token/check \ go/typechecker/check \ + $(go_types_check_omitted_since_it_calls_6g) \ hash/adler32/check \ hash/crc32/check \ hash/crc64/check \ hash/fnv/check \ http/cgi/check \ + http/fcgi/check \ + http/spdy/check \ + image/jpeg/check \ image/png/check \ + image/tiff/check \ + image/ycbcr/check \ index/suffixarray/check \ io/ioutil/check \ mime/multipart/check \ net/textproto/check \ $(os_inotify_check) \ + os/user/check \ os/signal/check \ path/filepath/check \ rpc/jsonrpc/check \ diff --git a/libgo/Makefile.in b/libgo/Makefile.in index cf84f730cdd..10d0a4e4e20 100644 --- a/libgo/Makefile.in +++ b/libgo/Makefile.in @@ -155,15 +155,17 @@ am__DEPENDENCIES_2 = asn1/asn1.lo big/big.lo bufio/bufio.lo \ encoding/binary.lo encoding/git85.lo encoding/hex.lo \ encoding/line.lo encoding/pem.lo exp/datafmt.lo exp/draw.lo \ exp/eval.lo go/ast.lo go/doc.lo go/parser.lo go/printer.lo \ - go/scanner.lo go/token.lo go/typechecker.lo hash/adler32.lo \ - hash/crc32.lo hash/crc64.lo hash/fnv.lo http/cgi.lo \ - http/httptest.lo http/pprof.lo image/jpeg.lo image/png.lo \ - index/suffixarray.lo io/ioutil.lo mime/multipart.lo \ - net/dict.lo net/textproto.lo $(am__DEPENDENCIES_1) \ - os/signal.lo path/filepath.lo rpc/jsonrpc.lo runtime/debug.lo \ - runtime/pprof.lo sync/atomic.lo sync/atomic_c.lo \ - syscalls/syscall.lo syscalls/errno.lo testing/testing.lo \ - testing/iotest.lo testing/quick.lo testing/script.lo + go/scanner.lo go/token.lo go/typechecker.lo go/types.lo \ + hash/adler32.lo hash/crc32.lo hash/crc64.lo hash/fnv.lo \ + http/cgi.lo http/fcgi.lo http/httptest.lo http/pprof.lo \ + http/spdy.lo image/gif.lo image/jpeg.lo image/png.lo \ + image/tiff.lo image/ycbcr.lo index/suffixarray.lo io/ioutil.lo \ + mime/multipart.lo net/dict.lo net/textproto.lo \ + $(am__DEPENDENCIES_1) os/user.lo os/signal.lo path/filepath.lo \ + rpc/jsonrpc.lo runtime/debug.lo runtime/pprof.lo \ + sync/atomic.lo sync/atomic_c.lo syscalls/syscall.lo \ + syscalls/errno.lo testing/testing.lo testing/iotest.lo \ + testing/quick.lo testing/script.lo libgo_la_DEPENDENCIES = $(am__DEPENDENCIES_2) $(am__DEPENDENCIES_1) \ $(am__DEPENDENCIES_1) $(am__DEPENDENCIES_1) \ $(am__DEPENDENCIES_1) @@ -196,7 +198,7 @@ am__libgo_la_SOURCES_DIST = runtime/go-append.c runtime/go-assert.c \ runtime/go-select.c runtime/go-semacquire.c \ runtime/go-send-big.c runtime/go-send-nb-big.c \ runtime/go-send-nb-small.c runtime/go-send-small.c \ - runtime/go-signal.c runtime/go-strcmp.c \ + runtime/go-setenv.c runtime/go-signal.c runtime/go-strcmp.c \ runtime/go-string-to-byte-array.c \ runtime/go-string-to-int-array.c runtime/go-strplus.c \ runtime/go-strslice.c runtime/go-trampoline.c \ @@ -233,7 +235,7 @@ am__objects_3 = go-append.lo go-assert.lo go-assert-interface.lo \ go-reflect-chan.lo go-reflect-map.lo go-rune.lo \ go-runtime-error.lo go-sched.lo go-select.lo go-semacquire.lo \ go-send-big.lo go-send-nb-big.lo go-send-nb-small.lo \ - go-send-small.lo go-signal.lo go-strcmp.lo \ + go-send-small.lo go-setenv.lo go-signal.lo go-strcmp.lo \ go-string-to-byte-array.lo go-string-to-int-array.lo \ go-strplus.lo go-strslice.lo go-trampoline.lo go-type-eface.lo \ go-type-error.lo go-type-identity.lo go-type-interface.lo \ @@ -689,7 +691,8 @@ toolexeclibgogo_DATA = \ go/printer.gox \ go/scanner.gox \ go/token.gox \ - go/typechecker.gox + go/typechecker.gox \ + go/types.gox toolexeclibgohashdir = $(toolexeclibgodir)/hash toolexeclibgohash_DATA = \ @@ -701,13 +704,18 @@ toolexeclibgohash_DATA = \ toolexeclibgohttpdir = $(toolexeclibgodir)/http toolexeclibgohttp_DATA = \ http/cgi.gox \ + http/fcgi.gox \ http/httptest.gox \ - http/pprof.gox + http/pprof.gox \ + http/spdy.gox toolexeclibgoimagedir = $(toolexeclibgodir)/image toolexeclibgoimage_DATA = \ + image/gif.gox \ image/jpeg.gox \ - image/png.gox + image/png.gox \ + image/tiff.gox \ + image/ycbcr.gox toolexeclibgoindexdir = $(toolexeclibgodir)/index toolexeclibgoindex_DATA = \ @@ -733,6 +741,7 @@ toolexeclibgoosdir = $(toolexeclibgodir)/os @LIBGO_IS_LINUX_TRUE@os_inotify_gox = toolexeclibgoos_DATA = \ $(os_inotify_gox) \ + os/user.gox \ os/signal.gox toolexeclibgopathdir = $(toolexeclibgodir)/path @@ -821,6 +830,7 @@ runtime_files = \ runtime/go-send-nb-big.c \ runtime/go-send-nb-small.c \ runtime/go-send-small.c \ + runtime/go-setenv.c \ runtime/go-signal.c \ runtime/go-strcmp.c \ runtime/go-string-to-byte-array.c \ @@ -952,6 +962,7 @@ go_http_files = \ go/http/persist.go \ go/http/request.go \ go/http/response.go \ + go/http/reverseproxy.go \ go/http/server.go \ go/http/status.go \ go/http/transfer.go \ @@ -1041,8 +1052,13 @@ go_mime_files = \ @LIBGO_IS_LINUX_FALSE@@LIBGO_IS_RTEMS_FALSE@go_net_newpollserver_file = go/net/newpollserver.go @LIBGO_IS_LINUX_TRUE@@LIBGO_IS_RTEMS_FALSE@go_net_newpollserver_file = go/net/newpollserver.go @LIBGO_IS_RTEMS_TRUE@go_net_newpollserver_file = go/net/newpollserver_rtems.go +@LIBGO_IS_LINUX_FALSE@go_net_cgo_file = go/net/cgo_bsd.go +@LIBGO_IS_LINUX_TRUE@go_net_cgo_file = go/net/cgo_linux.go +@LIBGO_IS_LINUX_FALSE@go_net_sock_file = go/net/sock_bsd.go +@LIBGO_IS_LINUX_TRUE@go_net_sock_file = go/net/sock_linux.go go_net_files = \ - go/net/cgo_stub.go \ + go/net/cgo_unix.go \ + $(go_net_cgo_file) \ go/net/dial.go \ go/net/dnsclient.go \ go/net/dnsconfig.go \ @@ -1061,6 +1077,7 @@ go_net_files = \ go/net/pipe.go \ go/net/port.go \ go/net/sock.go \ + $(go_net_sock_file) \ go/net/tcpsock.go \ go/net/udpsock.go \ go/net/unixsock.go @@ -1365,7 +1382,6 @@ go_crypto_subtle_files = \ go_crypto_tls_files = \ go/crypto/tls/alert.go \ - go/crypto/tls/ca_set.go \ go/crypto/tls/cipher_suites.go \ go/crypto/tls/common.go \ go/crypto/tls/conn.go \ @@ -1380,6 +1396,8 @@ go_crypto_twofish_files = \ go/crypto/twofish/twofish.go go_crypto_x509_files = \ + go/crypto/x509/cert_pool.go \ + go/crypto/x509/verify.go \ go/crypto/x509/x509.go go_crypto_xtea_files = \ @@ -1519,6 +1537,13 @@ go_go_typechecker_files = \ go/go/typechecker/typechecker.go \ go/go/typechecker/universe.go +go_go_types_files = \ + go/go/types/const.go \ + go/go/types/exportdata.go \ + go/go/types/gcimporter.go \ + go/go/types/types.go \ + go/go/types/universe.go + go_hash_adler32_files = \ go/hash/adler32/adler32.go @@ -1535,6 +1560,10 @@ go_http_cgi_files = \ go/http/cgi/child.go \ go/http/cgi/host.go +go_http_fcgi_files = \ + go/http/fcgi/child.go \ + go/http/fcgi/fcgi.go + go_http_httptest_files = \ go/http/httptest/recorder.go \ go/http/httptest/server.go @@ -1542,15 +1571,31 @@ go_http_httptest_files = \ go_http_pprof_files = \ go/http/pprof/pprof.go +go_http_spdy_files = \ + go/http/spdy/protocol.go + +go_image_gif_files = \ + go/image/gif/reader.go + go_image_jpeg_files = \ + go/image/jpeg/fdct.go \ go/image/jpeg/huffman.go \ go/image/jpeg/idct.go \ - go/image/jpeg/reader.go + go/image/jpeg/reader.go \ + go/image/jpeg/writer.go go_image_png_files = \ go/image/png/reader.go \ go/image/png/writer.go +go_image_tiff_files = \ + go/image/tiff/buffer.go \ + go/image/tiff/consts.go \ + go/image/tiff/reader.go + +go_image_ycbcr_files = \ + go/image/ycbcr/ycbcr.go + go_index_suffixarray_files = \ go/index/suffixarray/qsufsort.go \ go/index/suffixarray/suffixarray.go @@ -1560,6 +1605,7 @@ go_io_ioutil_files = \ go/io/ioutil/tempfile.go go_mime_multipart_files = \ + go/mime/multipart/formdata.go \ go/mime/multipart/multipart.go go_net_dict_files = \ @@ -1575,6 +1621,10 @@ go_net_textproto_files = \ go_os_inotify_files = \ go/os/inotify/inotify_linux.go +go_os_user_files = \ + go/os/user/user.go \ + go/os/user/lookup_unix.go + go_os_signal_files = \ go/os/signal/signal.go \ unix.go @@ -1816,21 +1866,28 @@ libgo_go_objs = \ go/scanner.lo \ go/token.lo \ go/typechecker.lo \ + go/types.lo \ hash/adler32.lo \ hash/crc32.lo \ hash/crc64.lo \ hash/fnv.lo \ http/cgi.lo \ + http/fcgi.lo \ http/httptest.lo \ http/pprof.lo \ + http/spdy.lo \ + image/gif.lo \ image/jpeg.lo \ image/png.lo \ + image/tiff.lo \ + image/ycbcr.lo \ index/suffixarray.lo \ io/ioutil.lo \ mime/multipart.lo \ net/dict.lo \ net/textproto.lo \ $(os_lib_inotify_lo) \ + os/user.lo \ os/signal.lo \ path/filepath.lo \ rpc/jsonrpc.lo \ @@ -2052,22 +2109,30 @@ TEST_PACKAGES = \ exp/datafmt/check \ exp/draw/check \ exp/eval/check \ + go/ast/check \ go/parser/check \ go/printer/check \ go/scanner/check \ go/token/check \ go/typechecker/check \ + $(go_types_check_omitted_since_it_calls_6g) \ hash/adler32/check \ hash/crc32/check \ hash/crc64/check \ hash/fnv/check \ http/cgi/check \ + http/fcgi/check \ + http/spdy/check \ + image/jpeg/check \ image/png/check \ + image/tiff/check \ + image/ycbcr/check \ index/suffixarray/check \ io/ioutil/check \ mime/multipart/check \ net/textproto/check \ $(os_inotify_check) \ + os/user/check \ os/signal/check \ path/filepath/check \ rpc/jsonrpc/check \ @@ -2270,6 +2335,7 @@ distclean-compile: @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/go-send-nb-big.Plo@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/go-send-nb-small.Plo@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/go-send-small.Plo@am__quote@ +@AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/go-setenv.Plo@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/go-signal.Plo@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/go-strcmp.Plo@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/go-string-to-byte-array.Plo@am__quote@ @@ -2750,6 +2816,13 @@ go-send-small.lo: runtime/go-send-small.c @AMDEP_TRUE@@am__fastdepCC_FALSE@ DEPDIR=$(DEPDIR) $(CCDEPMODE) $(depcomp) @AMDEPBACKSLASH@ @am__fastdepCC_FALSE@ $(LIBTOOL) --tag=CC $(AM_LIBTOOLFLAGS) $(LIBTOOLFLAGS) --mode=compile $(CC) $(DEFS) $(DEFAULT_INCLUDES) $(INCLUDES) $(AM_CPPFLAGS) $(CPPFLAGS) $(AM_CFLAGS) $(CFLAGS) -c -o go-send-small.lo `test -f 'runtime/go-send-small.c' || echo '$(srcdir)/'`runtime/go-send-small.c +go-setenv.lo: runtime/go-setenv.c +@am__fastdepCC_TRUE@ $(LIBTOOL) --tag=CC $(AM_LIBTOOLFLAGS) $(LIBTOOLFLAGS) --mode=compile $(CC) $(DEFS) $(DEFAULT_INCLUDES) $(INCLUDES) $(AM_CPPFLAGS) $(CPPFLAGS) $(AM_CFLAGS) $(CFLAGS) -MT go-setenv.lo -MD -MP -MF $(DEPDIR)/go-setenv.Tpo -c -o go-setenv.lo `test -f 'runtime/go-setenv.c' || echo '$(srcdir)/'`runtime/go-setenv.c +@am__fastdepCC_TRUE@ $(am__mv) $(DEPDIR)/go-setenv.Tpo $(DEPDIR)/go-setenv.Plo +@AMDEP_TRUE@@am__fastdepCC_FALSE@ source='runtime/go-setenv.c' object='go-setenv.lo' libtool=yes @AMDEPBACKSLASH@ +@AMDEP_TRUE@@am__fastdepCC_FALSE@ DEPDIR=$(DEPDIR) $(CCDEPMODE) $(depcomp) @AMDEPBACKSLASH@ +@am__fastdepCC_FALSE@ $(LIBTOOL) --tag=CC $(AM_LIBTOOLFLAGS) $(LIBTOOLFLAGS) --mode=compile $(CC) $(DEFS) $(DEFAULT_INCLUDES) $(INCLUDES) $(AM_CPPFLAGS) $(CPPFLAGS) $(AM_CFLAGS) $(CFLAGS) -c -o go-setenv.lo `test -f 'runtime/go-setenv.c' || echo '$(srcdir)/'`runtime/go-setenv.c + go-signal.lo: runtime/go-signal.c @am__fastdepCC_TRUE@ $(LIBTOOL) --tag=CC $(AM_LIBTOOLFLAGS) $(LIBTOOLFLAGS) --mode=compile $(CC) $(DEFS) $(DEFAULT_INCLUDES) $(INCLUDES) $(AM_CPPFLAGS) $(CPPFLAGS) $(AM_CFLAGS) $(CFLAGS) -MT go-signal.lo -MD -MP -MF $(DEPDIR)/go-signal.Tpo -c -o go-signal.lo `test -f 'runtime/go-signal.c' || echo '$(srcdir)/'`runtime/go-signal.c @am__fastdepCC_TRUE@ $(am__mv) $(DEPDIR)/go-signal.Tpo $(DEPDIR)/go-signal.Plo @@ -4114,11 +4187,12 @@ html/check: $(CHECK_DEPS) @$(CHECK) .PHONY: html/check -http/http.lo: $(go_http_files) bufio.gox bytes.gox container/vector.gox \ - crypto/rand.gox crypto/tls.gox encoding/base64.gox fmt.gox \ - io.gox io/ioutil.gox log.gox mime.gox mime/multipart.gox \ - net.gox net/textproto.gox os.gox path.gox path/filepath.gox \ - sort.gox strconv.gox strings.gox sync.gox time.gox utf8.gox +http/http.lo: $(go_http_files) bufio.gox bytes.gox compress/gzip.gox \ + container/vector.gox crypto/rand.gox crypto/tls.gox \ + encoding/base64.gox fmt.gox io.gox io/ioutil.gox log.gox \ + mime.gox mime/multipart.gox net.gox net/textproto.gox os.gox \ + path.gox path/filepath.gox sort.gox strconv.gox strings.gox \ + sync.gox time.gox utf8.gox $(BUILDPACKAGE) http/check: $(CHECK_DEPS) @$(CHECK) @@ -4158,7 +4232,7 @@ math/check: $(CHECK_DEPS) @$(CHECK) .PHONY: math/check -mime/mime.lo: $(go_mime_files) bufio.gox bytes.gox os.gox strings.gox \ +mime/mime.lo: $(go_mime_files) bufio.gox bytes.gox fmt.gox os.gox strings.gox \ sync.gox unicode.gox $(BUILDPACKAGE) mime/check: $(CHECK_DEPS) @@ -4166,8 +4240,8 @@ mime/check: $(CHECK_DEPS) .PHONY: mime/check net/net.lo: $(go_net_files) bytes.gox fmt.gox io.gox os.gox rand.gox \ - reflect.gox strconv.gox strings.gox sync.gox syscall.gox \ - time.gox + reflect.gox sort.gox strconv.gox strings.gox sync.gox \ + syscall.gox time.gox $(BUILDPACKAGE) net/check: $(CHECK_DEPS) @$(CHECK_ON_REQUEST) @@ -4348,8 +4422,8 @@ xml/check: $(CHECK_DEPS) @$(CHECK) .PHONY: xml/check -archive/tar.lo: $(go_archive_tar_files) bytes.gox io.gox os.gox strconv.gox \ - strings.gox +archive/tar.lo: $(go_archive_tar_files) bytes.gox io.gox io/ioutil.gox os.gox \ + strconv.gox strings.gox $(BUILDPACKAGE) archive/tar/check: $(CHECK_DEPS) @$(MKDIR_P) archive/tar @@ -4551,8 +4625,7 @@ crypto/ripemd160/check: $(CHECK_DEPS) .PHONY: crypto/ripemd160/check crypto/rsa.lo: $(go_crypto_rsa_files) big.gox crypto.gox crypto/sha1.gox \ - crypto/subtle.gox encoding/hex.gox hash.gox io.gox os.gox \ - sync.gox + crypto/subtle.gox encoding/hex.gox hash.gox io.gox os.gox $(BUILDPACKAGE) crypto/rsa/check: $(CHECK_DEPS) @$(MKDIR_P) crypto/rsa @@ -4587,13 +4660,13 @@ crypto/subtle/check: $(CHECK_DEPS) @$(CHECK) .PHONY: crypto/subtle/check -crypto/tls.lo: $(go_crypto_tls_files) big.gox bufio.gox bytes.gox \ - container/list.gox crypto.gox crypto/aes.gox crypto/cipher.gox \ - crypto/elliptic.gox crypto/hmac.gox crypto/md5.gox \ - crypto/rc4.gox crypto/rand.gox crypto/rsa.gox crypto/sha1.gox \ - crypto/subtle.gox crypto/rsa.gox crypto/sha1.gox \ - crypto/x509.gox encoding/pem.gox fmt.gox hash.gox io.gox \ - io/ioutil.gox net.gox os.gox strings.gox sync.gox time.gox +crypto/tls.lo: $(go_crypto_tls_files) big.gox bytes.gox crypto.gox \ + crypto/aes.gox crypto/cipher.gox crypto/elliptic.gox \ + crypto/hmac.gox crypto/md5.gox crypto/rand.gox crypto/rc4.gox \ + crypto/rsa.gox crypto/sha1.gox crypto/subtle.gox \ + crypto/x509.gox encoding/pem.gox hash.gox io.gox \ + io/ioutil.gox net.gox os.gox strconv.gox strings.gox sync.gox \ + time.gox $(BUILDPACKAGE) crypto/tls/check: $(CHECK_DEPS) @$(MKDIR_P) crypto/tls @@ -4607,9 +4680,10 @@ crypto/twofish/check: $(CHECK_DEPS) @$(CHECK) .PHONY: crypto/twofish/check -crypto/x509.lo: $(go_crypto_x509_files) asn1.gox big.gox container/vector.gox \ - crypto.gox crypto/rsa.gox crypto/sha1.gox hash.gox os.gox \ - strings.gox time.gox +crypto/x509.lo: $(go_crypto_x509_files) asn1.gox big.gox bytes.gox \ + container/vector.gox crypto.gox crypto/rsa.gox \ + crypto/sha1.gox encoding/pem.gox hash.gox os.gox strings.gox \ + time.gox $(BUILDPACKAGE) crypto/x509/check: $(CHECK_DEPS) @$(MKDIR_P) crypto/x509 @@ -4623,9 +4697,8 @@ crypto/xtea/check: $(CHECK_DEPS) @$(CHECK) .PHONY: crypto/xtea/check -crypto/openpgp/armor.lo: $(go_crypto_openpgp_armor_files) bytes.gox \ - crypto/openpgp/error.gox encoding/base64.gox \ - encoding/line.gox io.gox os.gox +crypto/openpgp/armor.lo: $(go_crypto_openpgp_armor_files) bufio.gox bytes.gox \ + crypto/openpgp/error.gox encoding/base64.gox io.gox os.gox $(BUILDPACKAGE) crypto/openpgp/armor/check: $(CHECK_DEPS) @$(MKDIR_P) crypto/openpgp/armor @@ -4777,7 +4850,7 @@ exp/datafmt/check: $(CHECK_DEPS) @$(CHECK) .PHONY: exp/datafmt/check -exp/draw.lo: $(go_exp_draw_files) image.gox os.gox +exp/draw.lo: $(go_exp_draw_files) image.gox image/ycbcr.gox os.gox $(BUILDPACKAGE) exp/draw/check: $(CHECK_DEPS) @$(MKDIR_P) exp/draw @@ -4851,6 +4924,15 @@ go/typechecker/check: $(CHECK_DEPS) @$(CHECK) .PHONY: go/typechecker/check +go/types.lo: $(go_go_types_files) big.gox bufio.gox fmt.gox go/ast.gox \ + go/token.gox io.gox os.gox path/filepath.gox runtime.gox \ + scanner.gox strconv.gox strings.gox + $(BUILDPACKAGE) +go/types/check: $(CHECK_DEPS) + @$(MKDIR_P) go/types + @$(CHECK) +.PHONY: go/types/check + hash/adler32.lo: $(go_hash_adler32_files) hash.gox os.gox $(BUILDPACKAGE) hash/adler32/check: $(CHECK_DEPS) @@ -4879,15 +4961,25 @@ hash/fnv/check: $(CHECK_DEPS) @$(CHECK) .PHONY: hash/fnv/check -http/cgi.lo: $(go_http_cgi_files) bufio.gox bytes.gox encoding/line.gox \ - exec.gox fmt.gox http.gox io.gox io/ioutil.gox log.gox \ - os.gox path/filepath.gox regexp.gox strconv.gox strings.gox +http/cgi.lo: $(go_http_cgi_files) bufio.gox bytes.gox crypto/tls.gox \ + exec.gox fmt.gox http.gox net.gox io.gox io/ioutil.gox \ + log.gox os.gox path/filepath.gox regexp.gox strconv.gox \ + strings.gox $(BUILDPACKAGE) http/cgi/check: $(CHECK_DEPS) @$(MKDIR_P) http/cgi @$(CHECK) .PHONY: http/cgi/check +http/fcgi.lo: $(go_http_fcgi_files) bufio.gox bytes.gox encoding/binary.gox \ + fmt.gox http.gox http/cgi.gox io.gox net.gox os.gox sync.gox \ + time.gox + $(BUILDPACKAGE) +http/fcgi/check: $(CHECK_DEPS) + @$(MKDIR_P) http/fcgi + @$(CHECK) +.PHONY: http/fcgi/check + http/httptest.lo: $(go_http_httptest_files) bytes.gox crypto/rand.gox \ crypto/tls.gox fmt.gox http.gox net.gox os.gox time.gox $(BUILDPACKAGE) @@ -4896,15 +4988,33 @@ http/httptest/check: $(CHECK_DEPS) @$(CHECK) .PHONY: http/httptest/check -http/pprof.lo: $(go_http_pprof_files) bufio.gox fmt.gox http.gox os.gox \ - runtime.gox runtime/pprof.gox strconv.gox strings.gox +http/pprof.lo: $(go_http_pprof_files) bufio.gox bytes.gox fmt.gox http.gox \ + os.gox runtime.gox runtime/pprof.gox strconv.gox strings.gox $(BUILDPACKAGE) http/pprof/check: $(CHECK_DEPS) @$(MKDIR_P) http/pprof @$(CHECK) .PHONY: http/pprof/check -image/jpeg.lo: $(go_image_jpeg_files) bufio.gox image.gox io.gox os.gox +http/spdy.lo: $(go_http_spdy_files) bytes.gox compress/zlib.gox \ + encoding/binary.gox http.gox io.gox os.gox strconv.gox \ + strings.gox sync.gox + $(BUILDPACKAGE) +http/spdy/check: $(CHECK_DEPS) + @$(MKDIR_P) http/spdy + @$(CHECK) +.PHONY: http/spdy/check + +image/gif.lo: $(go_image_gif_files) bufio.gox compress/lzw.gox fmt.gox \ + image.gox io.gox os.gox + $(BUILDPACKAGE) +image/gif/check: $(CHECK_DEPS) + @$(MKDIR_P) image/gif + @$(CHECK) +.PHONY: image/gif/check + +image/jpeg.lo: $(go_image_jpeg_files) bufio.gox image.gox image/ycbcr.gox \ + io.gox os.gox $(BUILDPACKAGE) image/jpeg/check: $(CHECK_DEPS) @$(MKDIR_P) image/jpeg @@ -4919,6 +5029,21 @@ image/png/check: $(CHECK_DEPS) @$(CHECK) .PHONY: image/png/check +image/tiff.lo: $(go_image_tiff_files) compress/lzw.gox compress/zlib.gox \ + encoding/binary.gox image.gox io.gox io/ioutil.gox os.gox + $(BUILDPACKAGE) +image/tiff/check: $(CHECK_DEPS) + @$(MKDIR_P) image/tiff + @$(CHECK) +.PHONY: image/tiff/check + +image/ycbcr.lo: $(go_image_ycbcr_files) image.gox + $(BUILDPACKAGE) +image/ycbcr/check: $(CHECK_DEPS) + @$(MKDIR_P) image/ycbcr + @$(CHECK) +.PHONY: image/ycbcr/check + index/suffixarray.lo: $(go_index_suffixarray_files) bytes.gox regexp.gox \ sort.gox $(BUILDPACKAGE) @@ -4935,8 +5060,9 @@ io/ioutil/check: $(CHECK_DEPS) @$(CHECK) .PHONY: io/ioutil/check -mime/multipart.lo: $(go_mime_multipart_files) bufio.gox bytes.gox io.gox \ - mime.gox net/textproto.gox os.gox regexp.gox strings.gox +mime/multipart.lo: $(go_mime_multipart_files) bufio.gox bytes.gox fmt.gox \ + io.gox io/ioutil.gox mime.gox net/textproto.gox os.gox \ + regexp.gox $(BUILDPACKAGE) mime/multipart/check: $(CHECK_DEPS) @$(MKDIR_P) mime/multipart @@ -4963,6 +5089,14 @@ os/inotify/check: $(CHECK_DEPS) @$(CHECK) .PHONY: os/inotify/check +os/user.lo: $(go_os_user_files) fmt.gox os.gox runtime.gox strconv.gox \ + strings.gox syscall.gox + $(BUILDPACKAGE) +os/user/check: $(CHECK_DEPS) + @$(MKDIR_P) os/user + @$(CHECK) +.PHONY: os/user/check + os/signal.lo: $(go_os_signal_files) runtime.gox strconv.gox $(BUILDPACKAGE) os/signal/check: $(CHECK_DEPS) @@ -5284,6 +5418,8 @@ go/token.gox: go/token.lo $(BUILDGOX) go/typechecker.gox: go/typechecker.lo $(BUILDGOX) +go/types.gox: go/types.lo + $(BUILDGOX) hash/adler32.gox: hash/adler32.lo $(BUILDGOX) @@ -5296,15 +5432,25 @@ hash/fnv.gox: hash/fnv.lo http/cgi.gox: http/cgi.lo $(BUILDGOX) +http/fcgi.gox: http/fcgi.lo + $(BUILDGOX) http/httptest.gox: http/httptest.lo $(BUILDGOX) http/pprof.gox: http/pprof.lo $(BUILDGOX) +http/spdy.gox: http/spdy.lo + $(BUILDGOX) +image/gif.gox: image/gif.lo + $(BUILDGOX) image/jpeg.gox: image/jpeg.lo $(BUILDGOX) image/png.gox: image/png.lo $(BUILDGOX) +image/tiff.gox: image/tiff.lo + $(BUILDGOX) +image/ycbcr.gox: image/ycbcr.lo + $(BUILDGOX) index/suffixarray.gox: index/suffixarray.lo $(BUILDGOX) @@ -5322,6 +5468,8 @@ net/textproto.gox: net/textproto.lo os/inotify.gox: os/inotify.lo $(BUILDGOX) +os/user.gox: os/user.lo + $(BUILDGOX) os/signal.gox: os/signal.lo $(BUILDGOX) diff --git a/libgo/go/archive/tar/common.go b/libgo/go/archive/tar/common.go index 5b781ff3d7d..52885876589 100644 --- a/libgo/go/archive/tar/common.go +++ b/libgo/go/archive/tar/common.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The tar package implements access to tar archives. +// Package tar implements access to tar archives. // It aims to cover most of the variations, including those produced // by GNU and BSD tars. // diff --git a/libgo/go/archive/tar/reader.go b/libgo/go/archive/tar/reader.go index 1b41196a999..ad06b6dac54 100644 --- a/libgo/go/archive/tar/reader.go +++ b/libgo/go/archive/tar/reader.go @@ -10,6 +10,7 @@ package tar import ( "bytes" "io" + "io/ioutil" "os" "strconv" ) @@ -27,13 +28,13 @@ var ( // tr := tar.NewReader(r) // for { // hdr, err := tr.Next() -// if err != nil { -// // handle error -// } -// if hdr == nil { +// if err == os.EOF { // // end of tar archive // break // } +// if err != nil { +// // handle error +// } // io.Copy(data, tr) // } type Reader struct { @@ -84,12 +85,6 @@ func (tr *Reader) octal(b []byte) int64 { return int64(x) } -type ignoreWriter struct{} - -func (ignoreWriter) Write(b []byte) (n int, err os.Error) { - return len(b), nil -} - // Skip any unread bytes in the existing file entry, as well as any alignment padding. func (tr *Reader) skipUnread() { nr := tr.nb + tr.pad // number of bytes to skip @@ -99,7 +94,7 @@ func (tr *Reader) skipUnread() { return } } - _, tr.err = io.Copyn(ignoreWriter{}, tr.r, nr) + _, tr.err = io.Copyn(ioutil.Discard, tr.r, nr) } func (tr *Reader) verifyChecksum(header []byte) bool { diff --git a/libgo/go/archive/zip/reader.go b/libgo/go/archive/zip/reader.go index 543007abfe0..17464c5d8e4 100644 --- a/libgo/go/archive/zip/reader.go +++ b/libgo/go/archive/zip/reader.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* -The zip package provides support for reading ZIP archives. +Package zip provides support for reading ZIP archives. See: http://www.pkware.com/documents/casestudies/APPNOTE.TXT @@ -35,6 +35,11 @@ type Reader struct { Comment string } +type ReadCloser struct { + f *os.File + Reader +} + type File struct { FileHeader zipr io.ReaderAt @@ -47,43 +52,60 @@ func (f *File) hasDataDescriptor() bool { return f.Flags&0x8 != 0 } -// OpenReader will open the Zip file specified by name and return a Reader. -func OpenReader(name string) (*Reader, os.Error) { +// OpenReader will open the Zip file specified by name and return a ReaderCloser. +func OpenReader(name string) (*ReadCloser, os.Error) { f, err := os.Open(name) if err != nil { return nil, err } fi, err := f.Stat() if err != nil { + f.Close() + return nil, err + } + r := new(ReadCloser) + if err := r.init(f, fi.Size); err != nil { + f.Close() return nil, err } - return NewReader(f, fi.Size) + return r, nil } // NewReader returns a new Reader reading from r, which is assumed to // have the given size in bytes. func NewReader(r io.ReaderAt, size int64) (*Reader, os.Error) { - end, err := readDirectoryEnd(r, size) - if err != nil { + zr := new(Reader) + if err := zr.init(r, size); err != nil { return nil, err } - z := &Reader{ - r: r, - File: make([]*File, end.directoryRecords), - Comment: end.comment, + return zr, nil +} + +func (z *Reader) init(r io.ReaderAt, size int64) os.Error { + end, err := readDirectoryEnd(r, size) + if err != nil { + return err } + z.r = r + z.File = make([]*File, end.directoryRecords) + z.Comment = end.comment rs := io.NewSectionReader(r, 0, size) if _, err = rs.Seek(int64(end.directoryOffset), os.SEEK_SET); err != nil { - return nil, err + return err } buf := bufio.NewReader(rs) for i := range z.File { z.File[i] = &File{zipr: r, zipsize: size} if err := readDirectoryHeader(z.File[i], buf); err != nil { - return nil, err + return err } } - return z, nil + return nil +} + +// Close closes the Zip file, rendering it unusable for I/O. +func (rc *ReadCloser) Close() os.Error { + return rc.f.Close() } // Open returns a ReadCloser that provides access to the File's contents. diff --git a/libgo/go/archive/zip/reader_test.go b/libgo/go/archive/zip/reader_test.go index 72e8cccfd47..c72cd9a2347 100644 --- a/libgo/go/archive/zip/reader_test.go +++ b/libgo/go/archive/zip/reader_test.go @@ -76,6 +76,12 @@ func readTestZip(t *testing.T, zt ZipTest) { return } + // bail if file is not zip + if err == FormatError { + return + } + defer z.Close() + // bail here if no Files expected to be tested // (there may actually be files in the zip, but we don't care) if zt.File == nil { diff --git a/libgo/go/asn1/asn1.go b/libgo/go/asn1/asn1.go index c5314517b34..5f470aed797 100644 --- a/libgo/go/asn1/asn1.go +++ b/libgo/go/asn1/asn1.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The asn1 package implements parsing of DER-encoded ASN.1 data structures, +// Package asn1 implements parsing of DER-encoded ASN.1 data structures, // as defined in ITU-T Rec X.690. // // See also ``A Layman's Guide to a Subset of ASN.1, BER, and DER,'' @@ -373,7 +373,7 @@ func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset i // parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse // a number of ASN.1 values from the given byte array and returns them as a // slice of Go values of the given type. -func parseSequenceOf(bytes []byte, sliceType *reflect.SliceType, elemType reflect.Type) (ret *reflect.SliceValue, err os.Error) { +func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type) (ret reflect.Value, err os.Error) { expectedTag, compoundType, ok := getUniversalType(elemType) if !ok { err = StructuralError{"unknown Go type for slice"} @@ -409,7 +409,7 @@ func parseSequenceOf(bytes []byte, sliceType *reflect.SliceType, elemType reflec params := fieldParameters{} offset := 0 for i := 0; i < numElements; i++ { - offset, err = parseField(ret.Elem(i), bytes, offset, params) + offset, err = parseField(ret.Index(i), bytes, offset, params) if err != nil { return } @@ -418,13 +418,13 @@ func parseSequenceOf(bytes []byte, sliceType *reflect.SliceType, elemType reflec } var ( - bitStringType = reflect.Typeof(BitString{}) - objectIdentifierType = reflect.Typeof(ObjectIdentifier{}) - enumeratedType = reflect.Typeof(Enumerated(0)) - flagType = reflect.Typeof(Flag(false)) - timeType = reflect.Typeof(&time.Time{}) - rawValueType = reflect.Typeof(RawValue{}) - rawContentsType = reflect.Typeof(RawContent(nil)) + bitStringType = reflect.TypeOf(BitString{}) + objectIdentifierType = reflect.TypeOf(ObjectIdentifier{}) + enumeratedType = reflect.TypeOf(Enumerated(0)) + flagType = reflect.TypeOf(Flag(false)) + timeType = reflect.TypeOf(&time.Time{}) + rawValueType = reflect.TypeOf(RawValue{}) + rawContentsType = reflect.TypeOf(RawContent(nil)) ) // invalidLength returns true iff offset + length > sliceLength, or if the @@ -461,13 +461,12 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam } result := RawValue{t.class, t.tag, t.isCompound, bytes[offset : offset+t.length], bytes[initOffset : offset+t.length]} offset += t.length - v.(*reflect.StructValue).Set(reflect.NewValue(result).(*reflect.StructValue)) + v.Set(reflect.ValueOf(result)) return } // Deal with the ANY type. - if ifaceType, ok := fieldType.(*reflect.InterfaceType); ok && ifaceType.NumMethod() == 0 { - ifaceValue := v.(*reflect.InterfaceValue) + if ifaceType := fieldType; ifaceType.Kind() == reflect.Interface && ifaceType.NumMethod() == 0 { var t tagAndLength t, offset, err = parseTagAndLength(bytes, offset) if err != nil { @@ -506,7 +505,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam return } if result != nil { - ifaceValue.Set(reflect.NewValue(result)) + v.Set(reflect.ValueOf(result)) } return } @@ -536,9 +535,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam err = StructuralError{"Zero length explicit tag was not an asn1.Flag"} return } - - flagValue := v.(*reflect.BoolValue) - flagValue.Set(true) + v.SetBool(true) return } } else { @@ -606,23 +603,20 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam switch fieldType { case objectIdentifierType: newSlice, err1 := parseObjectIdentifier(innerBytes) - sliceValue := v.(*reflect.SliceValue) - sliceValue.Set(reflect.MakeSlice(sliceValue.Type().(*reflect.SliceType), len(newSlice), len(newSlice))) + v.Set(reflect.MakeSlice(v.Type(), len(newSlice), len(newSlice))) if err1 == nil { - reflect.Copy(sliceValue, reflect.NewValue(newSlice).(reflect.ArrayOrSliceValue)) + reflect.Copy(v, reflect.ValueOf(newSlice)) } err = err1 return case bitStringType: - structValue := v.(*reflect.StructValue) bs, err1 := parseBitString(innerBytes) if err1 == nil { - structValue.Set(reflect.NewValue(bs).(*reflect.StructValue)) + v.Set(reflect.ValueOf(bs)) } err = err1 return case timeType: - ptrValue := v.(*reflect.PtrValue) var time *time.Time var err1 os.Error if universalTag == tagUTCTime { @@ -631,55 +625,53 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam time, err1 = parseGeneralizedTime(innerBytes) } if err1 == nil { - ptrValue.Set(reflect.NewValue(time).(*reflect.PtrValue)) + v.Set(reflect.ValueOf(time)) } err = err1 return case enumeratedType: parsedInt, err1 := parseInt(innerBytes) - enumValue := v.(*reflect.IntValue) if err1 == nil { - enumValue.Set(int64(parsedInt)) + v.SetInt(int64(parsedInt)) } err = err1 return case flagType: - flagValue := v.(*reflect.BoolValue) - flagValue.Set(true) + v.SetBool(true) return } - switch val := v.(type) { - case *reflect.BoolValue: + switch val := v; val.Kind() { + case reflect.Bool: parsedBool, err1 := parseBool(innerBytes) if err1 == nil { - val.Set(parsedBool) + val.SetBool(parsedBool) } err = err1 return - case *reflect.IntValue: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: switch val.Type().Kind() { case reflect.Int: parsedInt, err1 := parseInt(innerBytes) if err1 == nil { - val.Set(int64(parsedInt)) + val.SetInt(int64(parsedInt)) } err = err1 return case reflect.Int64: parsedInt, err1 := parseInt64(innerBytes) if err1 == nil { - val.Set(parsedInt) + val.SetInt(parsedInt) } err = err1 return } - case *reflect.StructValue: - structType := fieldType.(*reflect.StructType) + case reflect.Struct: + structType := fieldType if structType.NumField() > 0 && structType.Field(0).Type == rawContentsType { bytes := bytes[initOffset:offset] - val.Field(0).SetValue(reflect.NewValue(RawContent(bytes))) + val.Field(0).Set(reflect.ValueOf(RawContent(bytes))) } innerOffset := 0 @@ -697,11 +689,11 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam // adding elements to the end has been used in X.509 as the // version numbers have increased. return - case *reflect.SliceValue: - sliceType := fieldType.(*reflect.SliceType) + case reflect.Slice: + sliceType := fieldType if sliceType.Elem().Kind() == reflect.Uint8 { val.Set(reflect.MakeSlice(sliceType, len(innerBytes), len(innerBytes))) - reflect.Copy(val, reflect.NewValue(innerBytes).(reflect.ArrayOrSliceValue)) + reflect.Copy(val, reflect.ValueOf(innerBytes)) return } newSlice, err1 := parseSequenceOf(innerBytes, sliceType, sliceType.Elem()) @@ -710,7 +702,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam } err = err1 return - case *reflect.StringValue: + case reflect.String: var v string switch universalTag { case tagPrintableString: @@ -729,7 +721,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam err = SyntaxError{fmt.Sprintf("internal error: unknown string type %d", universalTag)} } if err == nil { - val.Set(v) + val.SetString(v) } return } @@ -748,9 +740,9 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) { if params.defaultValue == nil { return } - switch val := v.(type) { - case *reflect.IntValue: - val.Set(*params.defaultValue) + switch val := v; val.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + val.SetInt(*params.defaultValue) } return } @@ -806,7 +798,7 @@ func Unmarshal(b []byte, val interface{}) (rest []byte, err os.Error) { // UnmarshalWithParams allows field parameters to be specified for the // top-level element. The form of the params is the same as the field tags. func UnmarshalWithParams(b []byte, val interface{}, params string) (rest []byte, err os.Error) { - v := reflect.NewValue(val).(*reflect.PtrValue).Elem() + v := reflect.ValueOf(val).Elem() offset, err := parseField(v, b, 0, parseFieldParameters(params)) if err != nil { return nil, err diff --git a/libgo/go/asn1/asn1_test.go b/libgo/go/asn1/asn1_test.go index b7767656a42..78f56280524 100644 --- a/libgo/go/asn1/asn1_test.go +++ b/libgo/go/asn1/asn1_test.go @@ -267,11 +267,6 @@ func TestParseFieldParameters(t *testing.T) { } } -type unmarshalTest struct { - in []byte - out interface{} -} - type TestObjectIdentifierStruct struct { OID ObjectIdentifier } @@ -290,7 +285,10 @@ type TestElementsAfterString struct { A, B int } -var unmarshalTestData []unmarshalTest = []unmarshalTest{ +var unmarshalTestData = []struct { + in []byte + out interface{} +}{ {[]byte{0x02, 0x01, 0x42}, newInt(0x42)}, {[]byte{0x30, 0x08, 0x06, 0x06, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d}, &TestObjectIdentifierStruct{[]int{1, 2, 840, 113549}}}, {[]byte{0x03, 0x04, 0x06, 0x6e, 0x5d, 0xc0}, &BitString{[]byte{110, 93, 192}, 18}}, @@ -309,9 +307,7 @@ var unmarshalTestData []unmarshalTest = []unmarshalTest{ func TestUnmarshal(t *testing.T) { for i, test := range unmarshalTestData { - pv := reflect.MakeZero(reflect.NewValue(test.out).Type()) - zv := reflect.MakeZero(pv.Type().(*reflect.PtrType).Elem()) - pv.(*reflect.PtrValue).PointTo(zv) + pv := reflect.New(reflect.TypeOf(test.out).Elem()) val := pv.Interface() _, err := Unmarshal(test.in, val) if err != nil { diff --git a/libgo/go/asn1/common.go b/libgo/go/asn1/common.go index f2254a41bba..1589877477c 100644 --- a/libgo/go/asn1/common.go +++ b/libgo/go/asn1/common.go @@ -133,14 +133,14 @@ func getUniversalType(t reflect.Type) (tagNumber int, isCompound, ok bool) { case enumeratedType: return tagEnum, false, true } - switch t := t.(type) { - case *reflect.BoolType: + switch t.Kind() { + case reflect.Bool: return tagBoolean, false, true - case *reflect.IntType: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return tagInteger, false, true - case *reflect.StructType: + case reflect.Struct: return tagSequence, true, true - case *reflect.SliceType: + case reflect.Slice: if t.Elem().Kind() == reflect.Uint8 { return tagOctetString, false, true } @@ -148,7 +148,7 @@ func getUniversalType(t reflect.Type) (tagNumber int, isCompound, ok bool) { return tagSet, true, true } return tagSequence, true, true - case *reflect.StringType: + case reflect.String: return tagPrintableString, false, true } return 0, false, false diff --git a/libgo/go/asn1/marshal.go b/libgo/go/asn1/marshal.go index 57b8f20ba7f..a3e1145b895 100644 --- a/libgo/go/asn1/marshal.go +++ b/libgo/go/asn1/marshal.go @@ -125,6 +125,28 @@ func int64Length(i int64) (numBytes int) { return } +func marshalLength(out *forkableWriter, i int) (err os.Error) { + n := lengthLength(i) + + for ; n > 0; n-- { + err = out.WriteByte(byte(i >> uint((n-1)*8))) + if err != nil { + return + } + } + + return nil +} + +func lengthLength(i int) (numBytes int) { + numBytes = 1 + for i > 255 { + numBytes++ + i >>= 8 + } + return +} + func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err os.Error) { b := uint8(t.class) << 6 if t.isCompound { @@ -149,12 +171,12 @@ func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err os.Error) { } if t.length >= 128 { - l := int64Length(int64(t.length)) + l := lengthLength(t.length) err = out.WriteByte(0x80 | byte(l)) if err != nil { return } - err = marshalInt64(out, int64(t.length)) + err = marshalLength(out, t.length) if err != nil { return } @@ -314,28 +336,28 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier)) } - switch v := value.(type) { - case *reflect.BoolValue: - if v.Get() { + switch v := value; v.Kind() { + case reflect.Bool: + if v.Bool() { return out.WriteByte(255) } else { return out.WriteByte(0) } - case *reflect.IntValue: - return marshalInt64(out, int64(v.Get())) - case *reflect.StructValue: - t := v.Type().(*reflect.StructType) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return marshalInt64(out, int64(v.Int())) + case reflect.Struct: + t := v.Type() startingField := 0 // If the first element of the structure is a non-empty // RawContents, then we don't bother serialising the rest. if t.NumField() > 0 && t.Field(0).Type == rawContentsType { - s := v.Field(0).(*reflect.SliceValue) + s := v.Field(0) if s.Len() > 0 { bytes := make([]byte, s.Len()) for i := 0; i < s.Len(); i++ { - bytes[i] = uint8(s.Elem(i).(*reflect.UintValue).Get()) + bytes[i] = uint8(s.Index(i).Uint()) } /* The RawContents will contain the tag and * length fields but we'll also be writing @@ -357,12 +379,12 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter } } return - case *reflect.SliceValue: - sliceType := v.Type().(*reflect.SliceType) + case reflect.Slice: + sliceType := v.Type() if sliceType.Elem().Kind() == reflect.Uint8 { bytes := make([]byte, v.Len()) for i := 0; i < v.Len(); i++ { - bytes[i] = uint8(v.Elem(i).(*reflect.UintValue).Get()) + bytes[i] = uint8(v.Index(i).Uint()) } _, err = out.Write(bytes) return @@ -372,17 +394,17 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter for i := 0; i < v.Len(); i++ { var pre *forkableWriter pre, out = out.fork() - err = marshalField(pre, v.Elem(i), params) + err = marshalField(pre, v.Index(i), params) if err != nil { return } } return - case *reflect.StringValue: + case reflect.String: if params.stringType == tagIA5String { - return marshalIA5String(out, v.Get()) + return marshalIA5String(out, v.String()) } else { - return marshalPrintableString(out, v.Get()) + return marshalPrintableString(out, v.String()) } return } @@ -392,7 +414,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err os.Error) { // If the field is an interface{} then recurse into it. - if v, ok := v.(*reflect.InterfaceValue); ok && v.Type().(*reflect.InterfaceType).NumMethod() == 0 { + if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 { return marshalField(out, v.Elem(), params) } @@ -406,7 +428,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) return } - if params.optional && reflect.DeepEqual(v.Interface(), reflect.MakeZero(v.Type()).Interface()) { + if params.optional && reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) { return } @@ -471,7 +493,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) // Marshal returns the ASN.1 encoding of val. func Marshal(val interface{}) ([]byte, os.Error) { var out bytes.Buffer - v := reflect.NewValue(val) + v := reflect.ValueOf(val) f := newForkableWriter() err := marshalField(f, v, fieldParameters{}) if err != nil { diff --git a/libgo/go/asn1/marshal_test.go b/libgo/go/asn1/marshal_test.go index 85eafc9e4d2..cd165d20352 100644 --- a/libgo/go/asn1/marshal_test.go +++ b/libgo/go/asn1/marshal_test.go @@ -77,6 +77,30 @@ var marshalTests = []marshalTest{ {ObjectIdentifier([]int{1, 2, 3, 4}), "06032a0304"}, {ObjectIdentifier([]int{1, 2, 840, 133549, 1, 1, 5}), "06092a864888932d010105"}, {"test", "130474657374"}, + { + "" + + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", // This is 127 times 'x' + "137f" + + "7878787878787878787878787878787878787878787878787878787878787878" + + "7878787878787878787878787878787878787878787878787878787878787878" + + "7878787878787878787878787878787878787878787878787878787878787878" + + "78787878787878787878787878787878787878787878787878787878787878", + }, + { + "" + + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", // This is 128 times 'x' + "138180" + + "7878787878787878787878787878787878787878787878787878787878787878" + + "7878787878787878787878787878787878787878787878787878787878787878" + + "7878787878787878787878787878787878787878787878787878787878787878" + + "7878787878787878787878787878787878787878787878787878787878787878", + }, {ia5StringTest{"test"}, "3006160474657374"}, {printableStringTest{"test"}, "3006130474657374"}, {printableStringTest{"test*"}, "30071305746573742a"}, diff --git a/libgo/go/big/int.go b/libgo/go/big/int.go index ecd70e03ef1..f1ea7b1c2ec 100644 --- a/libgo/go/big/int.go +++ b/libgo/go/big/int.go @@ -337,6 +337,10 @@ func fmtbase(ch int) int { // 'x' (hexadecimal). // func (x *Int) Format(s fmt.State, ch int) { + if x == nil { + fmt.Fprint(s, "") + return + } if x.neg { fmt.Fprint(s, "-") } diff --git a/libgo/go/big/nat.go b/libgo/go/big/nat.go index a04d3b1d9c1..4848d427b39 100644 --- a/libgo/go/big/nat.go +++ b/libgo/go/big/nat.go @@ -2,11 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This file contains operations on unsigned multi-precision integers. -// These are the building blocks for the operations on signed integers -// and rationals. - -// This package implements multi-precision arithmetic (big numbers). +// Package big implements multi-precision arithmetic (big numbers). // The following numeric types are supported: // // - Int signed integers @@ -18,6 +14,10 @@ // package big +// This file contains operations on unsigned multi-precision integers. +// These are the building blocks for the operations on signed integers +// and rationals. + import "rand" // An unsigned integer x of the form diff --git a/libgo/go/bufio/bufio.go b/libgo/go/bufio/bufio.go index cd08be31b6a..eaae8bb42c3 100644 --- a/libgo/go/bufio/bufio.go +++ b/libgo/go/bufio/bufio.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements buffered I/O. It wraps an io.Reader or io.Writer +// Package bufio implements buffered I/O. It wraps an io.Reader or io.Writer // object, creating another object (Reader or Writer) that also implements // the interface but provides buffering and some help for textual I/O. package bufio @@ -282,6 +282,33 @@ func (b *Reader) ReadSlice(delim byte) (line []byte, err os.Error) { panic("not reached") } +// ReadLine tries to return a single line, not including the end-of-line bytes. +// If the line was too long for the buffer then isPrefix is set and the +// beginning of the line is returned. The rest of the line will be returned +// from future calls. isPrefix will be false when returning the last fragment +// of the line. The returned buffer is only valid until the next call to +// ReadLine. ReadLine either returns a non-nil line or it returns an error, +// never both. +func (b *Reader) ReadLine() (line []byte, isPrefix bool, err os.Error) { + line, err = b.ReadSlice('\n') + if err == ErrBufferFull { + return line, true, nil + } + + if len(line) == 0 { + return + } + err = nil + + if line[len(line)-1] == '\n' { + line = line[:len(line)-1] + } + if len(line) > 0 && line[len(line)-1] == '\r' { + line = line[:len(line)-1] + } + return +} + // ReadBytes reads until the first occurrence of delim in the input, // returning a slice containing the data up to and including the delimiter. // If ReadBytes encounters an error before finding a delimiter, diff --git a/libgo/go/bufio/bufio_test.go b/libgo/go/bufio/bufio_test.go index 8028e04dcd9..123adac29a4 100644 --- a/libgo/go/bufio/bufio_test.go +++ b/libgo/go/bufio/bufio_test.go @@ -9,6 +9,7 @@ import ( "bytes" "fmt" "io" + "io/ioutil" "os" "strings" "testing" @@ -570,3 +571,128 @@ func TestPeekThenUnreadRune(t *testing.T) { r.UnreadRune() r.ReadRune() // Used to panic here } + +var testOutput = []byte("0123456789abcdefghijklmnopqrstuvwxy") +var testInput = []byte("012\n345\n678\n9ab\ncde\nfgh\nijk\nlmn\nopq\nrst\nuvw\nxy") +var testInputrn = []byte("012\r\n345\r\n678\r\n9ab\r\ncde\r\nfgh\r\nijk\r\nlmn\r\nopq\r\nrst\r\nuvw\r\nxy\r\n\n\r\n") + +// TestReader wraps a []byte and returns reads of a specific length. +type testReader struct { + data []byte + stride int +} + +func (t *testReader) Read(buf []byte) (n int, err os.Error) { + n = t.stride + if n > len(t.data) { + n = len(t.data) + } + if n > len(buf) { + n = len(buf) + } + copy(buf, t.data) + t.data = t.data[n:] + if len(t.data) == 0 { + err = os.EOF + } + return +} + +func testReadLine(t *testing.T, input []byte) { + //for stride := 1; stride < len(input); stride++ { + for stride := 1; stride < 2; stride++ { + done := 0 + reader := testReader{input, stride} + l, _ := NewReaderSize(&reader, len(input)+1) + for { + line, isPrefix, err := l.ReadLine() + if len(line) > 0 && err != nil { + t.Errorf("ReadLine returned both data and error: %s", err) + } + if isPrefix { + t.Errorf("ReadLine returned prefix") + } + if err != nil { + if err != os.EOF { + t.Fatalf("Got unknown error: %s", err) + } + break + } + if want := testOutput[done : done+len(line)]; !bytes.Equal(want, line) { + t.Errorf("Bad line at stride %d: want: %x got: %x", stride, want, line) + } + done += len(line) + } + if done != len(testOutput) { + t.Errorf("ReadLine didn't return everything: got: %d, want: %d (stride: %d)", done, len(testOutput), stride) + } + } +} + +func TestReadLine(t *testing.T) { + testReadLine(t, testInput) + testReadLine(t, testInputrn) +} + +func TestLineTooLong(t *testing.T) { + buf := bytes.NewBuffer([]byte("aaabbbcc\n")) + l, _ := NewReaderSize(buf, 3) + line, isPrefix, err := l.ReadLine() + if !isPrefix || !bytes.Equal(line, []byte("aaa")) || err != nil { + t.Errorf("bad result for first line: %x %s", line, err) + } + line, isPrefix, err = l.ReadLine() + if !isPrefix || !bytes.Equal(line, []byte("bbb")) || err != nil { + t.Errorf("bad result for second line: %x", line) + } + line, isPrefix, err = l.ReadLine() + if isPrefix || !bytes.Equal(line, []byte("cc")) || err != nil { + t.Errorf("bad result for third line: %x", line) + } + line, isPrefix, err = l.ReadLine() + if isPrefix || err == nil { + t.Errorf("expected no more lines: %x %s", line, err) + } +} + +func TestReadAfterLines(t *testing.T) { + line1 := "line1" + restData := "line2\nline 3\n" + inbuf := bytes.NewBuffer([]byte(line1 + "\n" + restData)) + outbuf := new(bytes.Buffer) + maxLineLength := len(line1) + len(restData)/2 + l, _ := NewReaderSize(inbuf, maxLineLength) + line, isPrefix, err := l.ReadLine() + if isPrefix || err != nil || string(line) != line1 { + t.Errorf("bad result for first line: isPrefix=%v err=%v line=%q", isPrefix, err, string(line)) + } + n, err := io.Copy(outbuf, l) + if int(n) != len(restData) || err != nil { + t.Errorf("bad result for Read: n=%d err=%v", n, err) + } + if outbuf.String() != restData { + t.Errorf("bad result for Read: got %q; expected %q", outbuf.String(), restData) + } +} + +func TestReadEmptyBuffer(t *testing.T) { + l, _ := NewReaderSize(bytes.NewBuffer(nil), 10) + line, isPrefix, err := l.ReadLine() + if err != os.EOF { + t.Errorf("expected EOF from ReadLine, got '%s' %t %s", line, isPrefix, err) + } +} + +func TestLinesAfterRead(t *testing.T) { + l, _ := NewReaderSize(bytes.NewBuffer([]byte("foo")), 10) + _, err := ioutil.ReadAll(l) + if err != nil { + t.Error(err) + return + } + + line, isPrefix, err := l.ReadLine() + if err != os.EOF { + t.Errorf("expected EOF from ReadLine, got '%s' %t %s", line, isPrefix, err) + } +} diff --git a/libgo/go/bytes/bytes.go b/libgo/go/bytes/bytes.go index c12a1357383..0f9ac986371 100644 --- a/libgo/go/bytes/bytes.go +++ b/libgo/go/bytes/bytes.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The bytes package implements functions for the manipulation of byte slices. -// Analogous to the facilities of the strings package. +// Package bytes implements functions for the manipulation of byte slices. +// It is analogous to the facilities of the strings package. package bytes import ( diff --git a/libgo/go/cmath/abs.go b/libgo/go/cmath/abs.go index 725dc4e9821..f3199cad561 100644 --- a/libgo/go/cmath/abs.go +++ b/libgo/go/cmath/abs.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The cmath package provides basic constants -// and mathematical functions for complex numbers. +// Package cmath provides basic constants and mathematical functions for +// complex numbers. package cmath import "math" diff --git a/libgo/go/compress/flate/deflate.go b/libgo/go/compress/flate/deflate.go index 591b35c4463..a02a5e8d94b 100644 --- a/libgo/go/compress/flate/deflate.go +++ b/libgo/go/compress/flate/deflate.go @@ -143,10 +143,18 @@ func (d *compressor) fillWindow(index int) (int, os.Error) { d.blockStart = math.MaxInt32 } for i, h := range d.hashHead { - d.hashHead[i] = max(h-wSize, -1) + v := h - wSize + if v < -1 { + v = -1 + } + d.hashHead[i] = v } for i, h := range d.hashPrev { - d.hashPrev[i] = max(h-wSize, -1) + v := -h - wSize + if v < -1 { + v = -1 + } + d.hashPrev[i] = v } } count, err := d.r.Read(d.window[d.windowEnd:]) @@ -177,10 +185,18 @@ func (d *compressor) writeBlock(tokens []token, index int, eof bool) os.Error { // Try to find a match starting at index whose length is greater than prevSize. // We only look at chainCount possibilities before giving up. func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead int) (length, offset int, ok bool) { - win := d.window[0 : pos+min(maxMatchLength, lookahead)] + minMatchLook := maxMatchLength + if lookahead < minMatchLook { + minMatchLook = lookahead + } + + win := d.window[0 : pos+minMatchLook] // We quit when we get a match that's at least nice long - nice := min(d.niceMatch, len(win)-pos) + nice := len(win) - pos + if d.niceMatch < nice { + nice = d.niceMatch + } // If we've got a match that's good enough, only look in 1/4 the chain. tries := d.maxChainLength @@ -344,9 +360,12 @@ Loop: } prevLength := length prevOffset := offset - minIndex := max(index-maxOffset, 0) length = minMatchLength - 1 offset = 0 + minIndex := index - maxOffset + if minIndex < 0 { + minIndex = 0 + } if chainHead >= minIndex && (isFastDeflate && lookahead > minMatchLength-1 || @@ -477,6 +496,33 @@ func NewWriter(w io.Writer, level int) *Writer { return &Writer{pw, &d} } +// NewWriterDict is like NewWriter but initializes the new +// Writer with a preset dictionary. The returned Writer behaves +// as if the dictionary had been written to it without producing +// any compressed output. The compressed data written to w +// can only be decompressed by a Reader initialized with the +// same dictionary. +func NewWriterDict(w io.Writer, level int, dict []byte) *Writer { + dw := &dictWriter{w, false} + zw := NewWriter(dw, level) + zw.Write(dict) + zw.Flush() + dw.enabled = true + return zw +} + +type dictWriter struct { + w io.Writer + enabled bool +} + +func (w *dictWriter) Write(b []byte) (n int, err os.Error) { + if w.enabled { + return w.w.Write(b) + } + return len(b), nil +} + // A Writer takes data written to it and writes the compressed // form of that data to an underlying writer (see NewWriter). type Writer struct { diff --git a/libgo/go/compress/flate/deflate_test.go b/libgo/go/compress/flate/deflate_test.go index ed5884a4b78..650a8059ace 100644 --- a/libgo/go/compress/flate/deflate_test.go +++ b/libgo/go/compress/flate/deflate_test.go @@ -275,3 +275,49 @@ func TestDeflateInflateString(t *testing.T) { } testToFromWithLevel(t, 1, gold, "2.718281828...") } + +func TestReaderDict(t *testing.T) { + const ( + dict = "hello world" + text = "hello again world" + ) + var b bytes.Buffer + w := NewWriter(&b, 5) + w.Write([]byte(dict)) + w.Flush() + b.Reset() + w.Write([]byte(text)) + w.Close() + + r := NewReaderDict(&b, []byte(dict)) + data, err := ioutil.ReadAll(r) + if err != nil { + t.Fatal(err) + } + if string(data) != "hello again world" { + t.Fatalf("read returned %q want %q", string(data), text) + } +} + +func TestWriterDict(t *testing.T) { + const ( + dict = "hello world" + text = "hello again world" + ) + var b bytes.Buffer + w := NewWriter(&b, 5) + w.Write([]byte(dict)) + w.Flush() + b.Reset() + w.Write([]byte(text)) + w.Close() + + var b1 bytes.Buffer + w = NewWriterDict(&b1, 5, []byte(dict)) + w.Write([]byte(text)) + w.Close() + + if !bytes.Equal(b1.Bytes(), b.Bytes()) { + t.Fatalf("writer wrote %q want %q", b1.Bytes(), b.Bytes()) + } +} diff --git a/libgo/go/compress/flate/inflate.go b/libgo/go/compress/flate/inflate.go index 7dc8cf93bd9..320b80d0699 100644 --- a/libgo/go/compress/flate/inflate.go +++ b/libgo/go/compress/flate/inflate.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The flate package implements the DEFLATE compressed data -// format, described in RFC 1951. The gzip and zlib packages -// implement access to DEFLATE-based file formats. +// Package flate implements the DEFLATE compressed data format, described in +// RFC 1951. The gzip and zlib packages implement access to DEFLATE-based file +// formats. package flate import ( @@ -526,6 +526,20 @@ func (f *decompressor) dataBlock() os.Error { return nil } +func (f *decompressor) setDict(dict []byte) { + if len(dict) > len(f.hist) { + // Will only remember the tail. + dict = dict[len(dict)-len(f.hist):] + } + + f.hp = copy(f.hist[:], dict) + if f.hp == len(f.hist) { + f.hp = 0 + f.hfull = true + } + f.hw = f.hp +} + func (f *decompressor) moreBits() os.Error { c, err := f.r.ReadByte() if err != nil { @@ -618,3 +632,16 @@ func NewReader(r io.Reader) io.ReadCloser { go func() { pw.CloseWithError(f.decompress(r, pw)) }() return pr } + +// NewReaderDict is like NewReader but initializes the reader +// with a preset dictionary. The returned Reader behaves as if +// the uncompressed data stream started with the given dictionary, +// which has already been read. NewReaderDict is typically used +// to read data compressed by NewWriterDict. +func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser { + var f decompressor + f.setDict(dict) + pr, pw := io.Pipe() + go func() { pw.CloseWithError(f.decompress(r, pw)) }() + return pr +} diff --git a/libgo/go/compress/gzip/gunzip.go b/libgo/go/compress/gzip/gunzip.go index 3c0b3c5e5fc..b0ddc81d252 100644 --- a/libgo/go/compress/gzip/gunzip.go +++ b/libgo/go/compress/gzip/gunzip.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The gzip package implements reading and writing of -// gzip format compressed files, as specified in RFC 1952. +// Package gzip implements reading and writing of gzip format compressed files, +// as specified in RFC 1952. package gzip import ( diff --git a/libgo/go/compress/lzw/reader.go b/libgo/go/compress/lzw/reader.go index 8a540cbe6a1..a1cd2abc043 100644 --- a/libgo/go/compress/lzw/reader.go +++ b/libgo/go/compress/lzw/reader.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The lzw package implements the Lempel-Ziv-Welch compressed data format, +// Package lzw implements the Lempel-Ziv-Welch compressed data format, // described in T. A. Welch, ``A Technique for High-Performance Data // Compression'', Computer, 17(6) (June 1984), pp 8-19. // @@ -165,16 +165,19 @@ func decode1(pw *io.PipeWriter, r io.ByteReader, read func(*decoder) (uint16, os if _, err := w.Write(buf[i:]); err != nil { return err } - // Save what the hi code expands to. - suffix[hi] = uint8(c) - prefix[hi] = last + if last != invalidCode { + // Save what the hi code expands to. + suffix[hi] = uint8(c) + prefix[hi] = last + } default: return os.NewError("lzw: invalid code") } last, hi = code, hi+1 - if hi == overflow { + if hi >= overflow { if d.width == maxWidth { - return os.NewError("lzw: missing clear code") + last = invalidCode + continue } d.width++ overflow <<= 1 diff --git a/libgo/go/compress/lzw/reader_test.go b/libgo/go/compress/lzw/reader_test.go index 4b5dfaadea2..72121a6b569 100644 --- a/libgo/go/compress/lzw/reader_test.go +++ b/libgo/go/compress/lzw/reader_test.go @@ -112,12 +112,6 @@ func TestReader(t *testing.T) { } } -type devNull struct{} - -func (devNull) Write(p []byte) (int, os.Error) { - return len(p), nil -} - func benchmarkDecoder(b *testing.B, n int) { b.StopTimer() b.SetBytes(int64(n)) @@ -134,7 +128,7 @@ func benchmarkDecoder(b *testing.B, n int) { runtime.GC() b.StartTimer() for i := 0; i < b.N; i++ { - io.Copy(devNull{}, NewReader(bytes.NewBuffer(buf1), LSB, 8)) + io.Copy(ioutil.Discard, NewReader(bytes.NewBuffer(buf1), LSB, 8)) } } diff --git a/libgo/go/compress/lzw/writer_test.go b/libgo/go/compress/lzw/writer_test.go index e5815a03d5d..82464ecd1b0 100644 --- a/libgo/go/compress/lzw/writer_test.go +++ b/libgo/go/compress/lzw/writer_test.go @@ -113,7 +113,7 @@ func benchmarkEncoder(b *testing.B, n int) { runtime.GC() b.StartTimer() for i := 0; i < b.N; i++ { - w := NewWriter(devNull{}, LSB, 8) + w := NewWriter(ioutil.Discard, LSB, 8) w.Write(buf1) w.Close() } diff --git a/libgo/go/compress/zlib/reader.go b/libgo/go/compress/zlib/reader.go index 721f6ec5595..8a3ef1580aa 100644 --- a/libgo/go/compress/zlib/reader.go +++ b/libgo/go/compress/zlib/reader.go @@ -3,8 +3,8 @@ // license that can be found in the LICENSE file. /* -The zlib package implements reading and writing of zlib -format compressed data, as specified in RFC 1950. +Package zlib implements reading and writing of zlib format compressed data, +as specified in RFC 1950. The implementation provides filters that uncompress during reading and compress during writing. For example, to write compressed data @@ -36,7 +36,7 @@ const zlibDeflate = 8 var ChecksumError os.Error = os.ErrorString("zlib checksum error") var HeaderError os.Error = os.ErrorString("invalid zlib header") -var UnsupportedError os.Error = os.ErrorString("unsupported zlib format") +var DictionaryError os.Error = os.ErrorString("invalid zlib dictionary") type reader struct { r flate.Reader @@ -50,6 +50,12 @@ type reader struct { // The implementation buffers input and may read more data than necessary from r. // It is the caller's responsibility to call Close on the ReadCloser when done. func NewReader(r io.Reader) (io.ReadCloser, os.Error) { + return NewReaderDict(r, nil) +} + +// NewReaderDict is like NewReader but uses a preset dictionary. +// NewReaderDict ignores the dictionary if the compressed data does not refer to it. +func NewReaderDict(r io.Reader, dict []byte) (io.ReadCloser, os.Error) { z := new(reader) if fr, ok := r.(flate.Reader); ok { z.r = fr @@ -65,11 +71,19 @@ func NewReader(r io.Reader) (io.ReadCloser, os.Error) { return nil, HeaderError } if z.scratch[1]&0x20 != 0 { - // BUG(nigeltao): The zlib package does not implement the FDICT flag. - return nil, UnsupportedError + _, err = io.ReadFull(z.r, z.scratch[0:4]) + if err != nil { + return nil, err + } + checksum := uint32(z.scratch[0])<<24 | uint32(z.scratch[1])<<16 | uint32(z.scratch[2])<<8 | uint32(z.scratch[3]) + if checksum != adler32.Checksum(dict) { + return nil, DictionaryError + } + z.decompressor = flate.NewReaderDict(z.r, dict) + } else { + z.decompressor = flate.NewReader(z.r) } z.digest = adler32.New() - z.decompressor = flate.NewReader(z.r) return z, nil } diff --git a/libgo/go/compress/zlib/reader_test.go b/libgo/go/compress/zlib/reader_test.go index eaefc3a361a..195db446c9f 100644 --- a/libgo/go/compress/zlib/reader_test.go +++ b/libgo/go/compress/zlib/reader_test.go @@ -15,6 +15,7 @@ type zlibTest struct { desc string raw string compressed []byte + dict []byte err os.Error } @@ -27,6 +28,7 @@ var zlibTests = []zlibTest{ "", []byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01}, nil, + nil, }, { "goodbye", @@ -37,23 +39,27 @@ var zlibTests = []zlibTest{ 0x01, 0x00, 0x28, 0xa5, 0x05, 0x5e, }, nil, + nil, }, { "bad header", "", []byte{0x78, 0x9f, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01}, + nil, HeaderError, }, { "bad checksum", "", []byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00, 0x00, 0xff}, + nil, ChecksumError, }, { "not enough data", "", []byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00}, + nil, io.ErrUnexpectedEOF, }, { @@ -64,6 +70,33 @@ var zlibTests = []zlibTest{ 0x78, 0x9c, 0xff, }, nil, + nil, + }, + { + "dictionary", + "Hello, World!\n", + []byte{ + 0x78, 0xbb, 0x1c, 0x32, 0x04, 0x27, 0xf3, 0x00, + 0xb1, 0x75, 0x20, 0x1c, 0x45, 0x2e, 0x00, 0x24, + 0x12, 0x04, 0x74, + }, + []byte{ + 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x0a, + }, + nil, + }, + { + "wrong dictionary", + "", + []byte{ + 0x78, 0xbb, 0x1c, 0x32, 0x04, 0x27, 0xf3, 0x00, + 0xb1, 0x75, 0x20, 0x1c, 0x45, 0x2e, 0x00, 0x24, + 0x12, 0x04, 0x74, + }, + []byte{ + 0x48, 0x65, 0x6c, 0x6c, + }, + DictionaryError, }, } @@ -71,7 +104,7 @@ func TestDecompressor(t *testing.T) { b := new(bytes.Buffer) for _, tt := range zlibTests { in := bytes.NewBuffer(tt.compressed) - zlib, err := NewReader(in) + zlib, err := NewReaderDict(in, tt.dict) if err != nil { if err != tt.err { t.Errorf("%s: NewReader: %s", tt.desc, err) diff --git a/libgo/go/compress/zlib/writer.go b/libgo/go/compress/zlib/writer.go index 031586cd2bb..f1f9b285375 100644 --- a/libgo/go/compress/zlib/writer.go +++ b/libgo/go/compress/zlib/writer.go @@ -21,56 +21,80 @@ const ( DefaultCompression = flate.DefaultCompression ) -type writer struct { +// A Writer takes data written to it and writes the compressed +// form of that data to an underlying writer (see NewWriter). +type Writer struct { w io.Writer - compressor io.WriteCloser + compressor *flate.Writer digest hash.Hash32 err os.Error scratch [4]byte } // NewWriter calls NewWriterLevel with the default compression level. -func NewWriter(w io.Writer) (io.WriteCloser, os.Error) { +func NewWriter(w io.Writer) (*Writer, os.Error) { return NewWriterLevel(w, DefaultCompression) } -// NewWriterLevel creates a new io.WriteCloser that satisfies writes by compressing data written to w. +// NewWriterLevel calls NewWriterDict with no dictionary. +func NewWriterLevel(w io.Writer, level int) (*Writer, os.Error) { + return NewWriterDict(w, level, nil) +} + +// NewWriterDict creates a new io.WriteCloser that satisfies writes by compressing data written to w. // It is the caller's responsibility to call Close on the WriteCloser when done. // level is the compression level, which can be DefaultCompression, NoCompression, // or any integer value between BestSpeed and BestCompression (inclusive). -func NewWriterLevel(w io.Writer, level int) (io.WriteCloser, os.Error) { - z := new(writer) +// dict is the preset dictionary to compress with, or nil to use no dictionary. +func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, os.Error) { + z := new(Writer) // ZLIB has a two-byte header (as documented in RFC 1950). // The first four bits is the CINFO (compression info), which is 7 for the default deflate window size. // The next four bits is the CM (compression method), which is 8 for deflate. z.scratch[0] = 0x78 // The next two bits is the FLEVEL (compression level). The four values are: // 0=fastest, 1=fast, 2=default, 3=best. - // The next bit, FDICT, is unused, in this implementation. + // The next bit, FDICT, is set if a dictionary is given. // The final five FCHECK bits form a mod-31 checksum. switch level { case 0, 1: - z.scratch[1] = 0x01 + z.scratch[1] = 0 << 6 case 2, 3, 4, 5: - z.scratch[1] = 0x5e + z.scratch[1] = 1 << 6 case 6, -1: - z.scratch[1] = 0x9c + z.scratch[1] = 2 << 6 case 7, 8, 9: - z.scratch[1] = 0xda + z.scratch[1] = 3 << 6 default: return nil, os.NewError("level out of range") } + if dict != nil { + z.scratch[1] |= 1 << 5 + } + z.scratch[1] += uint8(31 - (uint16(z.scratch[0])<<8+uint16(z.scratch[1]))%31) _, err := w.Write(z.scratch[0:2]) if err != nil { return nil, err } + if dict != nil { + // The next four bytes are the Adler-32 checksum of the dictionary. + checksum := adler32.Checksum(dict) + z.scratch[0] = uint8(checksum >> 24) + z.scratch[1] = uint8(checksum >> 16) + z.scratch[2] = uint8(checksum >> 8) + z.scratch[3] = uint8(checksum >> 0) + _, err = w.Write(z.scratch[0:4]) + if err != nil { + return nil, err + } + } z.w = w z.compressor = flate.NewWriter(w, level) z.digest = adler32.New() return z, nil } -func (z *writer) Write(p []byte) (n int, err os.Error) { +func (z *Writer) Write(p []byte) (n int, err os.Error) { if z.err != nil { return 0, z.err } @@ -86,8 +110,17 @@ func (z *writer) Write(p []byte) (n int, err os.Error) { return } +// Flush flushes the underlying compressor. +func (z *Writer) Flush() os.Error { + if z.err != nil { + return z.err + } + z.err = z.compressor.Flush() + return z.err +} + // Calling Close does not close the wrapped io.Writer originally passed to NewWriter. -func (z *writer) Close() os.Error { +func (z *Writer) Close() os.Error { if z.err != nil { return z.err } diff --git a/libgo/go/compress/zlib/writer_test.go b/libgo/go/compress/zlib/writer_test.go index 7eb1cd49497..f94f2847006 100644 --- a/libgo/go/compress/zlib/writer_test.go +++ b/libgo/go/compress/zlib/writer_test.go @@ -16,13 +16,19 @@ var filenames = []string{ "../testdata/pi.txt", } -// Tests that compressing and then decompressing the given file at the given compression level +// Tests that compressing and then decompressing the given file at the given compression level and dictionary // yields equivalent bytes to the original file. -func testFileLevel(t *testing.T, fn string, level int) { +func testFileLevelDict(t *testing.T, fn string, level int, d string) { + // Read dictionary, if given. + var dict []byte + if d != "" { + dict = []byte(d) + } + // Read the file, as golden output. golden, err := os.Open(fn) if err != nil { - t.Errorf("%s (level=%d): %v", fn, level, err) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err) return } defer golden.Close() @@ -30,7 +36,7 @@ func testFileLevel(t *testing.T, fn string, level int) { // Read the file again, and push it through a pipe that compresses at the write end, and decompresses at the read end. raw, err := os.Open(fn) if err != nil { - t.Errorf("%s (level=%d): %v", fn, level, err) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err) return } piper, pipew := io.Pipe() @@ -38,9 +44,9 @@ func testFileLevel(t *testing.T, fn string, level int) { go func() { defer raw.Close() defer pipew.Close() - zlibw, err := NewWriterLevel(pipew, level) + zlibw, err := NewWriterDict(pipew, level, dict) if err != nil { - t.Errorf("%s (level=%d): %v", fn, level, err) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err) return } defer zlibw.Close() @@ -48,7 +54,7 @@ func testFileLevel(t *testing.T, fn string, level int) { for { n, err0 := raw.Read(b[0:]) if err0 != nil && err0 != os.EOF { - t.Errorf("%s (level=%d): %v", fn, level, err0) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0) return } _, err1 := zlibw.Write(b[0:n]) @@ -57,7 +63,7 @@ func testFileLevel(t *testing.T, fn string, level int) { return } if err1 != nil { - t.Errorf("%s (level=%d): %v", fn, level, err1) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1) return } if err0 == os.EOF { @@ -65,9 +71,9 @@ func testFileLevel(t *testing.T, fn string, level int) { } } }() - zlibr, err := NewReader(piper) + zlibr, err := NewReaderDict(piper, dict) if err != nil { - t.Errorf("%s (level=%d): %v", fn, level, err) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err) return } defer zlibr.Close() @@ -76,20 +82,20 @@ func testFileLevel(t *testing.T, fn string, level int) { b0, err0 := ioutil.ReadAll(golden) b1, err1 := ioutil.ReadAll(zlibr) if err0 != nil { - t.Errorf("%s (level=%d): %v", fn, level, err0) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0) return } if err1 != nil { - t.Errorf("%s (level=%d): %v", fn, level, err1) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1) return } if len(b0) != len(b1) { - t.Errorf("%s (level=%d): length mismatch %d versus %d", fn, level, len(b0), len(b1)) + t.Errorf("%s (level=%d, dict=%q): length mismatch %d versus %d", fn, level, d, len(b0), len(b1)) return } for i := 0; i < len(b0); i++ { if b0[i] != b1[i] { - t.Errorf("%s (level=%d): mismatch at %d, 0x%02x versus 0x%02x\n", fn, level, i, b0[i], b1[i]) + t.Errorf("%s (level=%d, dict=%q): mismatch at %d, 0x%02x versus 0x%02x\n", fn, level, d, i, b0[i], b1[i]) return } } @@ -97,10 +103,21 @@ func testFileLevel(t *testing.T, fn string, level int) { func TestWriter(t *testing.T) { for _, fn := range filenames { - testFileLevel(t, fn, DefaultCompression) - testFileLevel(t, fn, NoCompression) + testFileLevelDict(t, fn, DefaultCompression, "") + testFileLevelDict(t, fn, NoCompression, "") + for level := BestSpeed; level <= BestCompression; level++ { + testFileLevelDict(t, fn, level, "") + } + } +} + +func TestWriterDict(t *testing.T) { + const dictionary = "0123456789." + for _, fn := range filenames { + testFileLevelDict(t, fn, DefaultCompression, dictionary) + testFileLevelDict(t, fn, NoCompression, dictionary) for level := BestSpeed; level <= BestCompression; level++ { - testFileLevel(t, fn, level) + testFileLevelDict(t, fn, level, dictionary) } } } diff --git a/libgo/go/container/heap/heap.go b/libgo/go/container/heap/heap.go index 4435a57c4e8..f2b8a750a45 100644 --- a/libgo/go/container/heap/heap.go +++ b/libgo/go/container/heap/heap.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package provides heap operations for any type that implements +// Package heap provides heap operations for any type that implements // heap.Interface. // package heap diff --git a/libgo/go/container/heap/heap_test.go b/libgo/go/container/heap/heap_test.go index 89d444dd546..5eb54374ab2 100644 --- a/libgo/go/container/heap/heap_test.go +++ b/libgo/go/container/heap/heap_test.go @@ -2,11 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package heap +package heap_test import ( "testing" "container/vector" + . "container/heap" ) diff --git a/libgo/go/container/list/list.go b/libgo/go/container/list/list.go index c1ebcddaa7a..a3fd4b39f32 100755 --- a/libgo/go/container/list/list.go +++ b/libgo/go/container/list/list.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The list package implements a doubly linked list. +// Package list implements a doubly linked list. // // To iterate over a list (where l is a *List): // for e := l.Front(); e != nil; e = e.Next() { diff --git a/libgo/go/container/ring/ring.go b/libgo/go/container/ring/ring.go index 5925164e9db..cc870ce9364 100644 --- a/libgo/go/container/ring/ring.go +++ b/libgo/go/container/ring/ring.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The ring package implements operations on circular lists. +// Package ring implements operations on circular lists. package ring // A Ring is an element of a circular list, or ring. diff --git a/libgo/go/container/vector/defs.go b/libgo/go/container/vector/defs.go index a2febb6deeb..bfb5481fb89 100644 --- a/libgo/go/container/vector/defs.go +++ b/libgo/go/container/vector/defs.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The vector package implements containers for managing sequences -// of elements. Vectors grow and shrink dynamically as necessary. +// Package vector implements containers for managing sequences of elements. +// Vectors grow and shrink dynamically as necessary. package vector diff --git a/libgo/go/crypto/aes/const.go b/libgo/go/crypto/aes/const.go index 97a5b64ec64..25acd0d1702 100644 --- a/libgo/go/crypto/aes/const.go +++ b/libgo/go/crypto/aes/const.go @@ -2,12 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// AES constants - 8720 bytes of initialized data. - -// This package implements AES encryption (formerly Rijndael), -// as defined in U.S. Federal Information Processing Standards Publication 197. +// Package aes implements AES encryption (formerly Rijndael), as defined in +// U.S. Federal Information Processing Standards Publication 197. package aes +// This file contains AES constants - 8720 bytes of initialized data. + // http://www.csrc.nist.gov/publications/fips/fips197/fips-197.pdf // AES is based on the mathematical behavior of binary polynomials diff --git a/libgo/go/crypto/blowfish/cipher.go b/libgo/go/crypto/blowfish/cipher.go index 947f762d8bc..f3c5175acfa 100644 --- a/libgo/go/crypto/blowfish/cipher.go +++ b/libgo/go/crypto/blowfish/cipher.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements Bruce Schneier's Blowfish encryption algorithm. +// Package blowfish implements Bruce Schneier's Blowfish encryption algorithm. package blowfish // The code is a port of Bruce Schneier's C implementation. diff --git a/libgo/go/crypto/cast5/cast5.go b/libgo/go/crypto/cast5/cast5.go index 35f3e64b643..cb62e3132e8 100644 --- a/libgo/go/crypto/cast5/cast5.go +++ b/libgo/go/crypto/cast5/cast5.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements CAST5, as defined in RFC 2144. CAST5 is a common +// Package cast5 implements CAST5, as defined in RFC 2144. CAST5 is a common // OpenPGP cipher. package cast5 diff --git a/libgo/go/crypto/cipher/cipher.go b/libgo/go/crypto/cipher/cipher.go index 50516b23a13..1ffaa8c2c33 100644 --- a/libgo/go/crypto/cipher/cipher.go +++ b/libgo/go/crypto/cipher/cipher.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The cipher package implements standard block cipher modes -// that can be wrapped around low-level block cipher implementations. +// Package cipher implements standard block cipher modes that can be wrapped +// around low-level block cipher implementations. // See http://csrc.nist.gov/groups/ST/toolkit/BCM/current_modes.html // and NIST Special Publication 800-38A. package cipher diff --git a/libgo/go/crypto/crypto.go b/libgo/go/crypto/crypto.go index be6b34adf28..53672a4da3c 100644 --- a/libgo/go/crypto/crypto.go +++ b/libgo/go/crypto/crypto.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The crypto package collects common cryptographic constants. +// Package crypto collects common cryptographic constants. package crypto import ( diff --git a/libgo/go/crypto/elliptic/elliptic.go b/libgo/go/crypto/elliptic/elliptic.go index 2296e960777..335c9645dc6 100644 --- a/libgo/go/crypto/elliptic/elliptic.go +++ b/libgo/go/crypto/elliptic/elliptic.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The elliptic package implements several standard elliptic curves over prime -// fields +// Package elliptic implements several standard elliptic curves over prime +// fields. package elliptic // This package operates, internally, on Jacobian coordinates. For a given diff --git a/libgo/go/crypto/hmac/hmac.go b/libgo/go/crypto/hmac/hmac.go index 298fb2c0694..04ec86e9ab1 100644 --- a/libgo/go/crypto/hmac/hmac.go +++ b/libgo/go/crypto/hmac/hmac.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The hmac package implements the Keyed-Hash Message Authentication Code (HMAC) -// as defined in U.S. Federal Information Processing Standards Publication 198. +// Package hmac implements the Keyed-Hash Message Authentication Code (HMAC) as +// defined in U.S. Federal Information Processing Standards Publication 198. // An HMAC is a cryptographic hash that uses a key to sign a message. // The receiver verifies the hash by recomputing it using the same key. package hmac diff --git a/libgo/go/crypto/md4/md4.go b/libgo/go/crypto/md4/md4.go index ee46544a920..848d9552df5 100644 --- a/libgo/go/crypto/md4/md4.go +++ b/libgo/go/crypto/md4/md4.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the MD4 hash algorithm as defined in RFC 1320. +// Package md4 implements the MD4 hash algorithm as defined in RFC 1320. package md4 import ( diff --git a/libgo/go/crypto/md5/md5.go b/libgo/go/crypto/md5/md5.go index 8f93fc4b354..378faa6ec71 100644 --- a/libgo/go/crypto/md5/md5.go +++ b/libgo/go/crypto/md5/md5.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the MD5 hash algorithm as defined in RFC 1321. +// Package md5 implements the MD5 hash algorithm as defined in RFC 1321. package md5 import ( diff --git a/libgo/go/crypto/ocsp/ocsp.go b/libgo/go/crypto/ocsp/ocsp.go index f42d8088884..acd75b8b06e 100644 --- a/libgo/go/crypto/ocsp/ocsp.go +++ b/libgo/go/crypto/ocsp/ocsp.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package parses OCSP responses as specified in RFC 2560. OCSP responses +// Package ocsp parses OCSP responses as specified in RFC 2560. OCSP responses // are signed messages attesting to the validity of a certificate for a small // period of time. This is used to manage revocation for X.509 certificates. package ocsp diff --git a/libgo/go/crypto/openpgp/armor/armor.go b/libgo/go/crypto/openpgp/armor/armor.go index 0c5ae9d716c..8da612c5007 100644 --- a/libgo/go/crypto/openpgp/armor/armor.go +++ b/libgo/go/crypto/openpgp/armor/armor.go @@ -2,15 +2,15 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements OpenPGP ASCII Armor, see RFC 4880. OpenPGP Armor is +// Package armor implements OpenPGP ASCII Armor, see RFC 4880. OpenPGP Armor is // very similar to PEM except that it has an additional CRC checksum. package armor import ( + "bufio" "bytes" "crypto/openpgp/error" "encoding/base64" - "encoding/line" "io" "os" ) @@ -63,7 +63,7 @@ var armorEndOfLine = []byte("-----") // lineReader wraps a line based reader. It watches for the end of an armor // block and records the expected CRC value. type lineReader struct { - in *line.Reader + in *bufio.Reader buf []byte eof bool crc uint32 @@ -156,7 +156,7 @@ func (r *openpgpReader) Read(p []byte) (n int, err os.Error) { // given Reader is not usable after calling this function: an arbitary amount // of data may have been read past the end of the block. func Decode(in io.Reader) (p *Block, err os.Error) { - r := line.NewReader(in, 100) + r, _ := bufio.NewReaderSize(in, 100) var line []byte ignoreNext := false diff --git a/libgo/go/crypto/openpgp/armor/encode.go b/libgo/go/crypto/openpgp/armor/encode.go index 0f7de024127..99dee375ef4 100644 --- a/libgo/go/crypto/openpgp/armor/encode.go +++ b/libgo/go/crypto/openpgp/armor/encode.go @@ -18,9 +18,9 @@ var armorEndOfLineOut = []byte("-----\n") // writeSlices writes its arguments to the given Writer. func writeSlices(out io.Writer, slices ...[]byte) (err os.Error) { for _, s := range slices { - _, err := out.Write(s) + _, err = out.Write(s) if err != nil { - return + return err } } return diff --git a/libgo/go/crypto/openpgp/error/error.go b/libgo/go/crypto/openpgp/error/error.go index 053d1596726..3759ce16122 100644 --- a/libgo/go/crypto/openpgp/error/error.go +++ b/libgo/go/crypto/openpgp/error/error.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package contains common error types for the OpenPGP packages. +// Package error contains common error types for the OpenPGP packages. package error import ( diff --git a/libgo/go/crypto/openpgp/keys.go b/libgo/go/crypto/openpgp/keys.go index ecaa86f2828..6c03f882831 100644 --- a/libgo/go/crypto/openpgp/keys.go +++ b/libgo/go/crypto/openpgp/keys.go @@ -5,6 +5,7 @@ package openpgp import ( + "crypto/openpgp/armor" "crypto/openpgp/error" "crypto/openpgp/packet" "io" @@ -13,6 +14,8 @@ import ( // PublicKeyType is the armor type for a PGP public key. var PublicKeyType = "PGP PUBLIC KEY BLOCK" +// PrivateKeyType is the armor type for a PGP private key. +var PrivateKeyType = "PGP PRIVATE KEY BLOCK" // An Entity represents the components of an OpenPGP key: a primary public key // (which must be a signing key), one or more identities claimed by that key, @@ -101,37 +104,50 @@ func (el EntityList) DecryptionKeys() (keys []Key) { // ReadArmoredKeyRing reads one or more public/private keys from an armor keyring file. func ReadArmoredKeyRing(r io.Reader) (EntityList, os.Error) { - body, err := readArmored(r, PublicKeyType) + block, err := armor.Decode(r) + if err == os.EOF { + return nil, error.InvalidArgumentError("no armored data found") + } if err != nil { return nil, err } + if block.Type != PublicKeyType && block.Type != PrivateKeyType { + return nil, error.InvalidArgumentError("expected public or private key block, got: " + block.Type) + } - return ReadKeyRing(body) + return ReadKeyRing(block.Body) } -// ReadKeyRing reads one or more public/private keys, ignoring unsupported keys. +// ReadKeyRing reads one or more public/private keys. Unsupported keys are +// ignored as long as at least a single valid key is found. func ReadKeyRing(r io.Reader) (el EntityList, err os.Error) { packets := packet.NewReader(r) + var lastUnsupportedError os.Error for { var e *Entity e, err = readEntity(packets) if err != nil { if _, ok := err.(error.UnsupportedError); ok { + lastUnsupportedError = err err = readToNextPublicKey(packets) } if err == os.EOF { err = nil - return + break } if err != nil { el = nil - return + break } } else { el = append(el, e) } } + + if len(el) == 0 && err == nil { + err = lastUnsupportedError + } return } @@ -197,25 +213,28 @@ EachPacket: current.Name = pkt.Id current.UserId = pkt e.Identities[pkt.Id] = current - p, err = packets.Next() - if err == os.EOF { - err = io.ErrUnexpectedEOF - } - if err != nil { - if _, ok := err.(error.UnsupportedError); ok { + + for { + p, err = packets.Next() + if err == os.EOF { + return nil, io.ErrUnexpectedEOF + } else if err != nil { return nil, err } - return nil, error.StructuralError("identity self-signature invalid: " + err.String()) - } - current.SelfSignature, ok = p.(*packet.Signature) - if !ok { - return nil, error.StructuralError("user ID packet not followed by self signature") - } - if current.SelfSignature.SigType != packet.SigTypePositiveCert { - return nil, error.StructuralError("user ID self-signature with wrong type") - } - if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, current.SelfSignature); err != nil { - return nil, error.StructuralError("user ID self-signature invalid: " + err.String()) + + sig, ok := p.(*packet.Signature) + if !ok { + return nil, error.StructuralError("user ID packet not followed by self-signature") + } + + if sig.SigType == packet.SigTypePositiveCert && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId { + if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, sig); err != nil { + return nil, error.StructuralError("user ID self-signature invalid: " + err.String()) + } + current.SelfSignature = sig + break + } + current.Signatures = append(current.Signatures, sig) } case *packet.Signature: if current == nil { diff --git a/libgo/go/crypto/openpgp/packet/packet.go b/libgo/go/crypto/openpgp/packet/packet.go index 57ff3afbfc7..c0ec44dd8ec 100644 --- a/libgo/go/crypto/openpgp/packet/packet.go +++ b/libgo/go/crypto/openpgp/packet/packet.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements parsing and serialisation of OpenPGP packets, as +// Package packet implements parsing and serialisation of OpenPGP packets, as // specified in RFC 4880. package packet diff --git a/libgo/go/crypto/openpgp/packet/private_key.go b/libgo/go/crypto/openpgp/packet/private_key.go index 69448239029..fde2a9933d8 100644 --- a/libgo/go/crypto/openpgp/packet/private_key.go +++ b/libgo/go/crypto/openpgp/packet/private_key.go @@ -164,8 +164,10 @@ func (pk *PrivateKey) parseRSAPrivateKey(data []byte) (err os.Error) { } rsaPriv.D = new(big.Int).SetBytes(d) - rsaPriv.P = new(big.Int).SetBytes(p) - rsaPriv.Q = new(big.Int).SetBytes(q) + rsaPriv.Primes = make([]*big.Int, 2) + rsaPriv.Primes[0] = new(big.Int).SetBytes(p) + rsaPriv.Primes[1] = new(big.Int).SetBytes(q) + rsaPriv.Precompute() pk.PrivateKey = rsaPriv pk.Encrypted = false pk.encryptedData = nil diff --git a/libgo/go/crypto/openpgp/packet/public_key.go b/libgo/go/crypto/openpgp/packet/public_key.go index ebef481fb7f..cd4a9aebb60 100644 --- a/libgo/go/crypto/openpgp/packet/public_key.go +++ b/libgo/go/crypto/openpgp/packet/public_key.go @@ -15,6 +15,7 @@ import ( "hash" "io" "os" + "strconv" ) // PublicKey represents an OpenPGP public key. See RFC 4880, section 5.5.2. @@ -47,7 +48,7 @@ func (pk *PublicKey) parse(r io.Reader) (err os.Error) { case PubKeyAlgoDSA: err = pk.parseDSA(r) default: - err = error.UnsupportedError("public key type") + err = error.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo))) } if err != nil { return diff --git a/libgo/go/crypto/openpgp/read.go b/libgo/go/crypto/openpgp/read.go index ac6998f0d24..4f84dff82bb 100644 --- a/libgo/go/crypto/openpgp/read.go +++ b/libgo/go/crypto/openpgp/read.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This openpgp package implements high level operations on OpenPGP messages. +// Package openpgp implements high level operations on OpenPGP messages. package openpgp import ( diff --git a/libgo/go/crypto/openpgp/read_test.go b/libgo/go/crypto/openpgp/read_test.go index 6218d9990dd..423c85b0f27 100644 --- a/libgo/go/crypto/openpgp/read_test.go +++ b/libgo/go/crypto/openpgp/read_test.go @@ -230,6 +230,23 @@ func TestDetachedSignatureDSA(t *testing.T) { testDetachedSignature(t, kring, readerFromHex(detachedSignatureDSAHex), signedInput, "binary", testKey3KeyId) } +func TestReadingArmoredPrivateKey(t *testing.T) { + el, err := ReadArmoredKeyRing(bytes.NewBufferString(armoredPrivateKeyBlock)) + if err != nil { + t.Error(err) + } + if len(el) != 1 { + t.Errorf("got %d entities, wanted 1\n", len(el)) + } +} + +func TestNoArmoredData(t *testing.T) { + _, err := ReadArmoredKeyRing(bytes.NewBufferString("foo")) + if _, ok := err.(error.InvalidArgumentError); !ok { + t.Errorf("error was not an InvalidArgumentError: %s", err) + } +} + const testKey1KeyId = 0xA34D7E18C20C31BB const testKey3KeyId = 0x338934250CCC0360 @@ -259,3 +276,37 @@ const symmetricallyEncryptedCompressedHex = "8c0d04030302eb4a03808145d0d260c92f7 const dsaTestKeyHex = "9901a2044d6c49de110400cb5ce438cf9250907ac2ba5bf6547931270b89f7c4b53d9d09f4d0213a5ef2ec1f26806d3d259960f872a4a102ef1581ea3f6d6882d15134f21ef6a84de933cc34c47cc9106efe3bd84c6aec12e78523661e29bc1a61f0aab17fa58a627fd5fd33f5149153fbe8cd70edf3d963bc287ef875270ff14b5bfdd1bca4483793923b00a0fe46d76cb6e4cbdc568435cd5480af3266d610d303fe33ae8273f30a96d4d34f42fa28ce1112d425b2e3bf7ea553d526e2db6b9255e9dc7419045ce817214d1a0056dbc8d5289956a4b1b69f20f1105124096e6a438f41f2e2495923b0f34b70642607d45559595c7fe94d7fa85fc41bf7d68c1fd509ebeaa5f315f6059a446b9369c277597e4f474a9591535354c7e7f4fd98a08aa60400b130c24ff20bdfbf683313f5daebf1c9b34b3bdadfc77f2ddd72ee1fb17e56c473664bc21d66467655dd74b9005e3a2bacce446f1920cd7017231ae447b67036c9b431b8179deacd5120262d894c26bc015bffe3d827ba7087ad9b700d2ca1f6d16cc1786581e5dd065f293c31209300f9b0afcc3f7c08dd26d0a22d87580b4db41054657374204b65792033202844534129886204131102002205024d6c49de021b03060b090807030206150802090a0b0416020301021e01021780000a0910338934250ccc03607e0400a0bdb9193e8a6b96fc2dfc108ae848914b504481f100a09c4dc148cb693293a67af24dd40d2b13a9e36794" const dsaTestKeyPrivateHex = "9501bb044d6c49de110400cb5ce438cf9250907ac2ba5bf6547931270b89f7c4b53d9d09f4d0213a5ef2ec1f26806d3d259960f872a4a102ef1581ea3f6d6882d15134f21ef6a84de933cc34c47cc9106efe3bd84c6aec12e78523661e29bc1a61f0aab17fa58a627fd5fd33f5149153fbe8cd70edf3d963bc287ef875270ff14b5bfdd1bca4483793923b00a0fe46d76cb6e4cbdc568435cd5480af3266d610d303fe33ae8273f30a96d4d34f42fa28ce1112d425b2e3bf7ea553d526e2db6b9255e9dc7419045ce817214d1a0056dbc8d5289956a4b1b69f20f1105124096e6a438f41f2e2495923b0f34b70642607d45559595c7fe94d7fa85fc41bf7d68c1fd509ebeaa5f315f6059a446b9369c277597e4f474a9591535354c7e7f4fd98a08aa60400b130c24ff20bdfbf683313f5daebf1c9b34b3bdadfc77f2ddd72ee1fb17e56c473664bc21d66467655dd74b9005e3a2bacce446f1920cd7017231ae447b67036c9b431b8179deacd5120262d894c26bc015bffe3d827ba7087ad9b700d2ca1f6d16cc1786581e5dd065f293c31209300f9b0afcc3f7c08dd26d0a22d87580b4d00009f592e0619d823953577d4503061706843317e4fee083db41054657374204b65792033202844534129886204131102002205024d6c49de021b03060b090807030206150802090a0b0416020301021e01021780000a0910338934250ccc03607e0400a0bdb9193e8a6b96fc2dfc108ae848914b504481f100a09c4dc148cb693293a67af24dd40d2b13a9e36794" + +const armoredPrivateKeyBlock = `-----BEGIN PGP PRIVATE KEY BLOCK----- +Version: GnuPG v1.4.10 (GNU/Linux) + +lQHYBE2rFNoBBADFwqWQIW/DSqcB4yCQqnAFTJ27qS5AnB46ccAdw3u4Greeu3Bp +idpoHdjULy7zSKlwR1EA873dO/k/e11Ml3dlAFUinWeejWaK2ugFP6JjiieSsrKn +vWNicdCS4HTWn0X4sjl0ZiAygw6GNhqEQ3cpLeL0g8E9hnYzJKQ0LWJa0QARAQAB +AAP/TB81EIo2VYNmTq0pK1ZXwUpxCrvAAIG3hwKjEzHcbQznsjNvPUihZ+NZQ6+X +0HCfPAdPkGDCLCb6NavcSW+iNnLTrdDnSI6+3BbIONqWWdRDYJhqZCkqmG6zqSfL +IdkJgCw94taUg5BWP/AAeQrhzjChvpMQTVKQL5mnuZbUCeMCAN5qrYMP2S9iKdnk +VANIFj7656ARKt/nf4CBzxcpHTyB8+d2CtPDKCmlJP6vL8t58Jmih+kHJMvC0dzn +gr5f5+sCAOOe5gt9e0am7AvQWhdbHVfJU0TQJx+m2OiCJAqGTB1nvtBLHdJnfdC9 +TnXXQ6ZXibqLyBies/xeY2sCKL5qtTMCAKnX9+9d/5yQxRyrQUHt1NYhaXZnJbHx +q4ytu0eWz+5i68IYUSK69jJ1NWPM0T6SkqpB3KCAIv68VFm9PxqG1KmhSrQIVGVz +dCBLZXmIuAQTAQIAIgUCTasU2gIbAwYLCQgHAwIGFQgCCQoLBBYCAwECHgECF4AA +CgkQO9o98PRieSoLhgQAkLEZex02Qt7vGhZzMwuN0R22w3VwyYyjBx+fM3JFETy1 +ut4xcLJoJfIaF5ZS38UplgakHG0FQ+b49i8dMij0aZmDqGxrew1m4kBfjXw9B/v+ +eIqpODryb6cOSwyQFH0lQkXC040pjq9YqDsO5w0WYNXYKDnzRV0p4H1pweo2VDid +AdgETasU2gEEAN46UPeWRqKHvA99arOxee38fBt2CI08iiWyI8T3J6ivtFGixSqV +bRcPxYO/qLpVe5l84Nb3X71GfVXlc9hyv7CD6tcowL59hg1E/DC5ydI8K8iEpUmK +/UnHdIY5h8/kqgGxkY/T/hgp5fRQgW1ZoZxLajVlMRZ8W4tFtT0DeA+JABEBAAEA +A/0bE1jaaZKj6ndqcw86jd+QtD1SF+Cf21CWRNeLKnUds4FRRvclzTyUMuWPkUeX +TaNNsUOFqBsf6QQ2oHUBBK4VCHffHCW4ZEX2cd6umz7mpHW6XzN4DECEzOVksXtc +lUC1j4UB91DC/RNQqwX1IV2QLSwssVotPMPqhOi0ZLNY7wIA3n7DWKInxYZZ4K+6 +rQ+POsz6brEoRHwr8x6XlHenq1Oki855pSa1yXIARoTrSJkBtn5oI+f8AzrnN0BN +oyeQAwIA/7E++3HDi5aweWrViiul9cd3rcsS0dEnksPhvS0ozCJiHsq/6GFmy7J8 +QSHZPteedBnZyNp5jR+H7cIfVN3KgwH/Skq4PsuPhDq5TKK6i8Pc1WW8MA6DXTdU +nLkX7RGmMwjC0DBf7KWAlPjFaONAX3a8ndnz//fy1q7u2l9AZwrj1qa1iJ8EGAEC +AAkFAk2rFNoCGwwACgkQO9o98PRieSo2/QP/WTzr4ioINVsvN1akKuekmEMI3LAp +BfHwatufxxP1U+3Si/6YIk7kuPB9Hs+pRqCXzbvPRrI8NHZBmc8qIGthishdCYad +AHcVnXjtxrULkQFGbGvhKURLvS9WnzD/m1K2zzwxzkPTzT9/Yf06O6Mal5AdugPL +VrM0m72/jnpKo04= +=zNCn +-----END PGP PRIVATE KEY BLOCK-----` diff --git a/libgo/go/crypto/openpgp/s2k/s2k.go b/libgo/go/crypto/openpgp/s2k/s2k.go index 873b33dc0d5..93b7582fa06 100644 --- a/libgo/go/crypto/openpgp/s2k/s2k.go +++ b/libgo/go/crypto/openpgp/s2k/s2k.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the various OpenPGP string-to-key transforms as +// Package s2k implements the various OpenPGP string-to-key transforms as // specified in RFC 4800 section 3.7.1. package s2k diff --git a/libgo/go/crypto/rc4/rc4.go b/libgo/go/crypto/rc4/rc4.go index 65fd195f3de..7ee471093b4 100644 --- a/libgo/go/crypto/rc4/rc4.go +++ b/libgo/go/crypto/rc4/rc4.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements RC4 encryption, as defined in Bruce Schneier's +// Package rc4 implements RC4 encryption, as defined in Bruce Schneier's // Applied Cryptography. package rc4 diff --git a/libgo/go/crypto/ripemd160/ripemd160.go b/libgo/go/crypto/ripemd160/ripemd160.go index 6e88521c3f6..5aaca59a3cf 100644 --- a/libgo/go/crypto/ripemd160/ripemd160.go +++ b/libgo/go/crypto/ripemd160/ripemd160.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the RIPEMD-160 hash algorithm. +// Package ripemd160 implements the RIPEMD-160 hash algorithm. package ripemd160 // RIPEMD-160 is designed by by Hans Dobbertin, Antoon Bosselaers, and Bart diff --git a/libgo/go/crypto/rsa/pkcs1v15.go b/libgo/go/crypto/rsa/pkcs1v15.go index 9a7184127db..3defa62ea6d 100644 --- a/libgo/go/crypto/rsa/pkcs1v15.go +++ b/libgo/go/crypto/rsa/pkcs1v15.go @@ -149,10 +149,10 @@ func nonZeroRandomBytes(s []byte, rand io.Reader) (err os.Error) { // precompute a prefix of the digest value that makes a valid ASN1 DER string // with the correct contents. var hashPrefixes = map[crypto.Hash][]byte{ - crypto.MD5: []byte{0x30, 0x20, 0x30, 0x0c, 0x06, 0x08, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x02, 0x05, 0x05, 0x00, 0x04, 0x10}, - crypto.SHA1: []byte{0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14}, - crypto.SHA256: []byte{0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, 0x00, 0x04, 0x20}, - crypto.SHA384: []byte{0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, 0x00, 0x04, 0x30}, + crypto.MD5: {0x30, 0x20, 0x30, 0x0c, 0x06, 0x08, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x02, 0x05, 0x05, 0x00, 0x04, 0x10}, + crypto.SHA1: {0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14}, + crypto.SHA256: {0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, 0x00, 0x04, 0x20}, + crypto.SHA384: {0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, 0x00, 0x04, 0x30}, crypto.SHA512: {0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, 0x00, 0x04, 0x40}, crypto.MD5SHA1: {}, // A special TLS case which doesn't use an ASN1 prefix. crypto.RIPEMD160: {0x30, 0x20, 0x30, 0x08, 0x06, 0x06, 0x28, 0xcf, 0x06, 0x03, 0x00, 0x31, 0x04, 0x14}, diff --git a/libgo/go/crypto/rsa/pkcs1v15_test.go b/libgo/go/crypto/rsa/pkcs1v15_test.go index 30a4824a6b0..d69bacfd685 100644 --- a/libgo/go/crypto/rsa/pkcs1v15_test.go +++ b/libgo/go/crypto/rsa/pkcs1v15_test.go @@ -197,12 +197,6 @@ func TestVerifyPKCS1v15(t *testing.T) { } } -func bigFromString(s string) *big.Int { - ret := new(big.Int) - ret.SetString(s, 10) - return ret -} - // In order to generate new test vectors you'll need the PEM form of this key: // -----BEGIN RSA PRIVATE KEY----- // MIIBOgIBAAJBALKZD0nEffqM1ACuak0bijtqE2QrI/KLADv7l3kK3ppMyCuLKoF0 @@ -216,10 +210,12 @@ func bigFromString(s string) *big.Int { var rsaPrivateKey = &PrivateKey{ PublicKey: PublicKey{ - N: bigFromString("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"), + N: fromBase10("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"), E: 65537, }, - D: bigFromString("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"), - P: bigFromString("98920366548084643601728869055592650835572950932266967461790948584315647051443"), - Q: bigFromString("94560208308847015747498523884063394671606671904944666360068158221458669711639"), + D: fromBase10("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"), + Primes: []*big.Int{ + fromBase10("98920366548084643601728869055592650835572950932266967461790948584315647051443"), + fromBase10("94560208308847015747498523884063394671606671904944666360068158221458669711639"), + }, } diff --git a/libgo/go/crypto/rsa/rsa.go b/libgo/go/crypto/rsa/rsa.go index b3b212c2066..e1813dbf938 100644 --- a/libgo/go/crypto/rsa/rsa.go +++ b/libgo/go/crypto/rsa/rsa.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements RSA encryption as specified in PKCS#1. +// Package rsa implements RSA encryption as specified in PKCS#1. package rsa // TODO(agl): Add support for PSS padding. @@ -13,7 +13,6 @@ import ( "hash" "io" "os" - "sync" ) var bigZero = big.NewInt(0) @@ -90,50 +89,60 @@ type PublicKey struct { // A PrivateKey represents an RSA key type PrivateKey struct { - PublicKey // public part. - D *big.Int // private exponent - P, Q, R *big.Int // prime factors of N (R may be nil) - - rwMutex sync.RWMutex // protects the following - dP, dQ, dR *big.Int // D mod (P-1) (or mod Q-1 etc) - qInv *big.Int // q^-1 mod p - pq *big.Int // P*Q - tr *big.Int // pq·tr ≡ 1 mod r + PublicKey // public part. + D *big.Int // private exponent + Primes []*big.Int // prime factors of N, has >= 2 elements. + + // Precomputed contains precomputed values that speed up private + // operations, if availible. + Precomputed PrecomputedValues +} + +type PrecomputedValues struct { + Dp, Dq *big.Int // D mod (P-1) (or mod Q-1) + Qinv *big.Int // Q^-1 mod Q + + // CRTValues is used for the 3rd and subsequent primes. Due to a + // historical accident, the CRT for the first two primes is handled + // differently in PKCS#1 and interoperability is sufficiently + // important that we mirror this. + CRTValues []CRTValue +} + +// CRTValue contains the precomputed chinese remainder theorem values. +type CRTValue struct { + Exp *big.Int // D mod (prime-1). + Coeff *big.Int // R·Coeff ≡ 1 mod Prime. + R *big.Int // product of primes prior to this (inc p and q). } // Validate performs basic sanity checks on the key. // It returns nil if the key is valid, or else an os.Error describing a problem. func (priv *PrivateKey) Validate() os.Error { - // Check that p, q and, maybe, r are prime. Note that this is just a - // sanity check. Since the random witnesses chosen by ProbablyPrime are - // deterministic, given the candidate number, it's easy for an attack - // to generate composites that pass this test. - if !big.ProbablyPrime(priv.P, 20) { - return os.ErrorString("P is composite") - } - if !big.ProbablyPrime(priv.Q, 20) { - return os.ErrorString("Q is composite") - } - if priv.R != nil && !big.ProbablyPrime(priv.R, 20) { - return os.ErrorString("R is composite") + // Check that the prime factors are actually prime. Note that this is + // just a sanity check. Since the random witnesses chosen by + // ProbablyPrime are deterministic, given the candidate number, it's + // easy for an attack to generate composites that pass this test. + for _, prime := range priv.Primes { + if !big.ProbablyPrime(prime, 20) { + return os.ErrorString("Prime factor is composite") + } } - // Check that p*q*r == n. - modulus := new(big.Int).Mul(priv.P, priv.Q) - if priv.R != nil { - modulus.Mul(modulus, priv.R) + // Check that Πprimes == n. + modulus := new(big.Int).Set(bigOne) + for _, prime := range priv.Primes { + modulus.Mul(modulus, prime) } if modulus.Cmp(priv.N) != 0 { return os.ErrorString("invalid modulus") } - // Check that e and totient(p, q, r) are coprime. - pminus1 := new(big.Int).Sub(priv.P, bigOne) - qminus1 := new(big.Int).Sub(priv.Q, bigOne) - totient := new(big.Int).Mul(pminus1, qminus1) - if priv.R != nil { - rminus1 := new(big.Int).Sub(priv.R, bigOne) - totient.Mul(totient, rminus1) + // Check that e and totient(Πprimes) are coprime. + totient := new(big.Int).Set(bigOne) + for _, prime := range priv.Primes { + pminus1 := new(big.Int).Sub(prime, bigOne) + totient.Mul(totient, pminus1) } e := big.NewInt(int64(priv.E)) gcd := new(big.Int) @@ -143,7 +152,7 @@ func (priv *PrivateKey) Validate() os.Error { if gcd.Cmp(bigOne) != 0 { return os.ErrorString("invalid public exponent E") } - // Check that de ≡ 1 (mod totient(p, q, r)) + // Check that de ≡ 1 (mod totient(Πprimes)) de := new(big.Int).Mul(priv.D, e) de.Mod(de, totient) if de.Cmp(bigOne) != 0 { @@ -154,6 +163,20 @@ func (priv *PrivateKey) Validate() os.Error { // GenerateKey generates an RSA keypair of the given bit size. func GenerateKey(rand io.Reader, bits int) (priv *PrivateKey, err os.Error) { + return GenerateMultiPrimeKey(rand, 2, bits) +} + +// GenerateMultiPrimeKey generates a multi-prime RSA keypair of the given bit +// size, as suggested in [1]. Although the public keys are compatible +// (actually, indistinguishable) from the 2-prime case, the private keys are +// not. Thus it may not be possible to export multi-prime private keys in +// certain formats or to subsequently import them into other code. +// +// Table 1 in [2] suggests maximum numbers of primes for a given size. +// +// [1] US patent 4405829 (1972, expired) +// [2] http://www.cacr.math.uwaterloo.ca/techreports/2006/cacr2006-16.pdf +func GenerateMultiPrimeKey(rand io.Reader, nprimes int, bits int) (priv *PrivateKey, err os.Error) { priv = new(PrivateKey) // Smaller public exponents lead to faster public key // operations. Since the exponent must be coprime to @@ -165,100 +188,41 @@ func GenerateKey(rand io.Reader, bits int) (priv *PrivateKey, err os.Error) { // [1] http://marc.info/?l=cryptography&m=115694833312008&w=2 priv.E = 3 - pminus1 := new(big.Int) - qminus1 := new(big.Int) - totient := new(big.Int) - - for { - p, err := randomPrime(rand, bits/2) - if err != nil { - return nil, err - } - - q, err := randomPrime(rand, bits/2) - if err != nil { - return nil, err - } - - if p.Cmp(q) == 0 { - continue - } - - n := new(big.Int).Mul(p, q) - pminus1.Sub(p, bigOne) - qminus1.Sub(q, bigOne) - totient.Mul(pminus1, qminus1) - - g := new(big.Int) - priv.D = new(big.Int) - y := new(big.Int) - e := big.NewInt(int64(priv.E)) - big.GcdInt(g, priv.D, y, e, totient) - - if g.Cmp(bigOne) == 0 { - priv.D.Add(priv.D, totient) - priv.P = p - priv.Q = q - priv.N = n - - break - } + if nprimes < 2 { + return nil, os.ErrorString("rsa.GenerateMultiPrimeKey: nprimes must be >= 2") } - return -} - -// Generate3PrimeKey generates a 3-prime RSA keypair of the given bit size, as -// suggested in [1]. Although the public keys are compatible (actually, -// indistinguishable) from the 2-prime case, the private keys are not. Thus it -// may not be possible to export 3-prime private keys in certain formats or to -// subsequently import them into other code. -// -// Table 1 in [2] suggests that size should be >= 1024 when using 3 primes. -// -// [1] US patent 4405829 (1972, expired) -// [2] http://www.cacr.math.uwaterloo.ca/techreports/2006/cacr2006-16.pdf -func Generate3PrimeKey(rand io.Reader, bits int) (priv *PrivateKey, err os.Error) { - priv = new(PrivateKey) - priv.E = 3 - - pminus1 := new(big.Int) - qminus1 := new(big.Int) - rminus1 := new(big.Int) - totient := new(big.Int) + primes := make([]*big.Int, nprimes) +NextSetOfPrimes: for { - p, err := randomPrime(rand, bits/3) - if err != nil { - return nil, err - } - - todo := bits - p.BitLen() - q, err := randomPrime(rand, todo/2) - if err != nil { - return nil, err + todo := bits + for i := 0; i < nprimes; i++ { + primes[i], err = randomPrime(rand, todo/(nprimes-i)) + if err != nil { + return nil, err + } + todo -= primes[i].BitLen() } - todo -= q.BitLen() - r, err := randomPrime(rand, todo) - if err != nil { - return nil, err + // Make sure that primes is pairwise unequal. + for i, prime := range primes { + for j := 0; j < i; j++ { + if prime.Cmp(primes[j]) == 0 { + continue NextSetOfPrimes + } + } } - if p.Cmp(q) == 0 || - q.Cmp(r) == 0 || - r.Cmp(p) == 0 { - continue + n := new(big.Int).Set(bigOne) + totient := new(big.Int).Set(bigOne) + pminus1 := new(big.Int) + for _, prime := range primes { + n.Mul(n, prime) + pminus1.Sub(prime, bigOne) + totient.Mul(totient, pminus1) } - n := new(big.Int).Mul(p, q) - n.Mul(n, r) - pminus1.Sub(p, bigOne) - qminus1.Sub(q, bigOne) - rminus1.Sub(r, bigOne) - totient.Mul(pminus1, qminus1) - totient.Mul(totient, rminus1) - g := new(big.Int) priv.D = new(big.Int) y := new(big.Int) @@ -267,15 +231,14 @@ func Generate3PrimeKey(rand io.Reader, bits int) (priv *PrivateKey, err os.Error if g.Cmp(bigOne) == 0 { priv.D.Add(priv.D, totient) - priv.P = p - priv.Q = q - priv.R = r + priv.Primes = primes priv.N = n break } } + priv.Precompute() return } @@ -409,23 +372,34 @@ func modInverse(a, n *big.Int) (ia *big.Int, ok bool) { return x, true } -// precompute performs some calculations that speed up private key operations +// Precompute performs some calculations that speed up private key operations // in the future. -func (priv *PrivateKey) precompute() { - priv.dP = new(big.Int).Sub(priv.P, bigOne) - priv.dP.Mod(priv.D, priv.dP) +func (priv *PrivateKey) Precompute() { + if priv.Precomputed.Dp != nil { + return + } - priv.dQ = new(big.Int).Sub(priv.Q, bigOne) - priv.dQ.Mod(priv.D, priv.dQ) + priv.Precomputed.Dp = new(big.Int).Sub(priv.Primes[0], bigOne) + priv.Precomputed.Dp.Mod(priv.D, priv.Precomputed.Dp) - priv.qInv = new(big.Int).ModInverse(priv.Q, priv.P) + priv.Precomputed.Dq = new(big.Int).Sub(priv.Primes[1], bigOne) + priv.Precomputed.Dq.Mod(priv.D, priv.Precomputed.Dq) - if priv.R != nil { - priv.dR = new(big.Int).Sub(priv.R, bigOne) - priv.dR.Mod(priv.D, priv.dR) + priv.Precomputed.Qinv = new(big.Int).ModInverse(priv.Primes[1], priv.Primes[0]) - priv.pq = new(big.Int).Mul(priv.P, priv.Q) - priv.tr = new(big.Int).ModInverse(priv.pq, priv.R) + r := new(big.Int).Mul(priv.Primes[0], priv.Primes[1]) + priv.Precomputed.CRTValues = make([]CRTValue, len(priv.Primes)-2) + for i := 2; i < len(priv.Primes); i++ { + prime := priv.Primes[i] + values := &priv.Precomputed.CRTValues[i-2] + + values.Exp = new(big.Int).Sub(prime, bigOne) + values.Exp.Mod(priv.D, values.Exp) + + values.R = new(big.Int).Set(r) + values.Coeff = new(big.Int).ModInverse(r, prime) + + r.Mul(r, prime) } } @@ -463,53 +437,41 @@ func decrypt(rand io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err os.E } bigE := big.NewInt(int64(priv.E)) rpowe := new(big.Int).Exp(r, bigE, priv.N) - c.Mul(c, rpowe) - c.Mod(c, priv.N) - } - - priv.rwMutex.RLock() - - if priv.dP == nil && priv.P != nil { - priv.rwMutex.RUnlock() - priv.rwMutex.Lock() - if priv.dP == nil && priv.P != nil { - priv.precompute() - } - priv.rwMutex.Unlock() - priv.rwMutex.RLock() + cCopy := new(big.Int).Set(c) + cCopy.Mul(cCopy, rpowe) + cCopy.Mod(cCopy, priv.N) + c = cCopy } - if priv.dP == nil { + if priv.Precomputed.Dp == nil { m = new(big.Int).Exp(c, priv.D, priv.N) } else { // We have the precalculated values needed for the CRT. - m = new(big.Int).Exp(c, priv.dP, priv.P) - m2 := new(big.Int).Exp(c, priv.dQ, priv.Q) + m = new(big.Int).Exp(c, priv.Precomputed.Dp, priv.Primes[0]) + m2 := new(big.Int).Exp(c, priv.Precomputed.Dq, priv.Primes[1]) m.Sub(m, m2) if m.Sign() < 0 { - m.Add(m, priv.P) + m.Add(m, priv.Primes[0]) } - m.Mul(m, priv.qInv) - m.Mod(m, priv.P) - m.Mul(m, priv.Q) + m.Mul(m, priv.Precomputed.Qinv) + m.Mod(m, priv.Primes[0]) + m.Mul(m, priv.Primes[1]) m.Add(m, m2) - if priv.dR != nil { - // 3-prime CRT. - m2.Exp(c, priv.dR, priv.R) + for i, values := range priv.Precomputed.CRTValues { + prime := priv.Primes[2+i] + m2.Exp(c, values.Exp, prime) m2.Sub(m2, m) - m2.Mul(m2, priv.tr) - m2.Mod(m2, priv.R) + m2.Mul(m2, values.Coeff) + m2.Mod(m2, prime) if m2.Sign() < 0 { - m2.Add(m2, priv.R) + m2.Add(m2, prime) } - m2.Mul(m2, priv.pq) + m2.Mul(m2, values.R) m.Add(m, m2) } } - priv.rwMutex.RUnlock() - if ir != nil { // Unblind. m.Mul(m, ir) diff --git a/libgo/go/crypto/rsa/rsa_test.go b/libgo/go/crypto/rsa/rsa_test.go index d8a936eb68f..c36bca1cd37 100644 --- a/libgo/go/crypto/rsa/rsa_test.go +++ b/libgo/go/crypto/rsa/rsa_test.go @@ -30,7 +30,20 @@ func Test3PrimeKeyGeneration(t *testing.T) { } size := 768 - priv, err := Generate3PrimeKey(rand.Reader, size) + priv, err := GenerateMultiPrimeKey(rand.Reader, 3, size) + if err != nil { + t.Errorf("failed to generate key") + } + testKeyBasics(t, priv) +} + +func Test4PrimeKeyGeneration(t *testing.T) { + if testing.Short() { + return + } + + size := 768 + priv, err := GenerateMultiPrimeKey(rand.Reader, 4, size) if err != nil { t.Errorf("failed to generate key") } @@ -45,6 +58,7 @@ func testKeyBasics(t *testing.T, priv *PrivateKey) { pub := &priv.PublicKey m := big.NewInt(42) c := encrypt(new(big.Int), pub, m) + m2, err := decrypt(nil, priv, c) if err != nil { t.Errorf("error while decrypting: %s", err) @@ -59,7 +73,7 @@ func testKeyBasics(t *testing.T, priv *PrivateKey) { t.Errorf("error while decrypting (blind): %s", err) } if m.Cmp(m3) != 0 { - t.Errorf("(blind) got:%v, want:%v", m3, m) + t.Errorf("(blind) got:%v, want:%v (%#v)", m3, m, priv) } } @@ -77,10 +91,12 @@ func BenchmarkRSA2048Decrypt(b *testing.B) { E: 3, }, D: fromBase10("9542755287494004433998723259516013739278699355114572217325597900889416163458809501304132487555642811888150937392013824621448709836142886006653296025093941418628992648429798282127303704957273845127141852309016655778568546006839666463451542076964744073572349705538631742281931858219480985907271975884773482372966847639853897890615456605598071088189838676728836833012254065983259638538107719766738032720239892094196108713378822882383694456030043492571063441943847195939549773271694647657549658603365629458610273821292232646334717612674519997533901052790334279661754176490593041941863932308687197618671528035670452762731"), - P: fromBase10("130903255182996722426771613606077755295583329135067340152947172868415809027537376306193179624298874215608270802054347609836776473930072411958753044562214537013874103802006369634761074377213995983876788718033850153719421695468704276694983032644416930879093914927146648402139231293035971427838068945045019075433"), - Q: fromBase10("109348945610485453577574767652527472924289229538286649661240938988020367005475727988253438647560958573506159449538793540472829815903949343191091817779240101054552748665267574271163617694640513549693841337820602726596756351006149518830932261246698766355347898158548465400674856021497190430791824869615170301029"), + Primes: []*big.Int{ + fromBase10("130903255182996722426771613606077755295583329135067340152947172868415809027537376306193179624298874215608270802054347609836776473930072411958753044562214537013874103802006369634761074377213995983876788718033850153719421695468704276694983032644416930879093914927146648402139231293035971427838068945045019075433"), + fromBase10("109348945610485453577574767652527472924289229538286649661240938988020367005475727988253438647560958573506159449538793540472829815903949343191091817779240101054552748665267574271163617694640513549693841337820602726596756351006149518830932261246698766355347898158548465400674856021497190430791824869615170301029"), + }, } - priv.precompute() + priv.Precompute() c := fromBase10("1000") @@ -99,11 +115,13 @@ func Benchmark3PrimeRSA2048Decrypt(b *testing.B) { E: 3, }, D: fromBase10("10897585948254795600358846499957366070880176878341177571733155050184921896034527397712889205732614568234385175145686545381899460748279607074689061600935843283397424506622998458510302603922766336783617368686090042765718290914099334449154829375179958369993407724946186243249568928237086215759259909861748642124071874879861299389874230489928271621259294894142840428407196932444474088857746123104978617098858619445675532587787023228852383149557470077802718705420275739737958953794088728369933811184572620857678792001136676902250566845618813972833750098806496641114644760255910789397593428910198080271317419213080834885003"), - P: fromBase10("1025363189502892836833747188838978207017355117492483312747347695538428729137306368764177201532277413433182799108299960196606011786562992097313508180436744488171474690412562218914213688661311117337381958560443"), - Q: fromBase10("3467903426626310123395340254094941045497208049900750380025518552334536945536837294961497712862519984786362199788654739924501424784631315081391467293694361474867825728031147665777546570788493758372218019373"), - R: fromBase10("4597024781409332673052708605078359346966325141767460991205742124888960305710298765592730135879076084498363772408626791576005136245060321874472727132746643162385746062759369754202494417496879741537284589047"), + Primes: []*big.Int{ + fromBase10("1025363189502892836833747188838978207017355117492483312747347695538428729137306368764177201532277413433182799108299960196606011786562992097313508180436744488171474690412562218914213688661311117337381958560443"), + fromBase10("3467903426626310123395340254094941045497208049900750380025518552334536945536837294961497712862519984786362199788654739924501424784631315081391467293694361474867825728031147665777546570788493758372218019373"), + fromBase10("4597024781409332673052708605078359346966325141767460991205742124888960305710298765592730135879076084498363772408626791576005136245060321874472727132746643162385746062759369754202494417496879741537284589047"), + }, } - priv.precompute() + priv.Precompute() c := fromBase10("1000") diff --git a/libgo/go/crypto/sha1/sha1.go b/libgo/go/crypto/sha1/sha1.go index e6aa096e2a6..788d1ff5552 100644 --- a/libgo/go/crypto/sha1/sha1.go +++ b/libgo/go/crypto/sha1/sha1.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the SHA1 hash algorithm as defined in RFC 3174. +// Package sha1 implements the SHA1 hash algorithm as defined in RFC 3174. package sha1 import ( diff --git a/libgo/go/crypto/sha256/sha256.go b/libgo/go/crypto/sha256/sha256.go index 69b356b4e51..a2c058d180e 100644 --- a/libgo/go/crypto/sha256/sha256.go +++ b/libgo/go/crypto/sha256/sha256.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the SHA224 and SHA256 hash algorithms as defined in FIPS 180-2. +// Package sha256 implements the SHA224 and SHA256 hash algorithms as defined +// in FIPS 180-2. package sha256 import ( diff --git a/libgo/go/crypto/sha512/sha512.go b/libgo/go/crypto/sha512/sha512.go index 7e9f330e594..78f5fe26f80 100644 --- a/libgo/go/crypto/sha512/sha512.go +++ b/libgo/go/crypto/sha512/sha512.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the SHA384 and SHA512 hash algorithms as defined in FIPS 180-2. +// Package sha512 implements the SHA384 and SHA512 hash algorithms as defined +// in FIPS 180-2. package sha512 import ( diff --git a/libgo/go/crypto/subtle/constant_time.go b/libgo/go/crypto/subtle/constant_time.go index a3d70b9c96e..57dbe9db555 100644 --- a/libgo/go/crypto/subtle/constant_time.go +++ b/libgo/go/crypto/subtle/constant_time.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements functions that are often useful in cryptographic +// Package subtle implements functions that are often useful in cryptographic // code but require careful thought to use correctly. package subtle diff --git a/libgo/go/crypto/tls/ca_set.go b/libgo/go/crypto/tls/ca_set.go deleted file mode 100644 index ae00ac55868..00000000000 --- a/libgo/go/crypto/tls/ca_set.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "crypto/x509" - "encoding/pem" - "strings" -) - -// A CASet is a set of certificates. -type CASet struct { - bySubjectKeyId map[string][]*x509.Certificate - byName map[string][]*x509.Certificate -} - -// NewCASet returns a new, empty CASet. -func NewCASet() *CASet { - return &CASet{ - make(map[string][]*x509.Certificate), - make(map[string][]*x509.Certificate), - } -} - -func nameToKey(name *x509.Name) string { - return strings.Join(name.Country, ",") + "/" + strings.Join(name.Organization, ",") + "/" + strings.Join(name.OrganizationalUnit, ",") + "/" + name.CommonName -} - -// FindVerifiedParent attempts to find the certificate in s which has signed -// the given certificate. If no such certificate can be found or the signature -// doesn't match, it returns nil. -func (s *CASet) FindVerifiedParent(cert *x509.Certificate) (parent *x509.Certificate) { - var candidates []*x509.Certificate - - if len(cert.AuthorityKeyId) > 0 { - candidates = s.bySubjectKeyId[string(cert.AuthorityKeyId)] - } - if len(candidates) == 0 { - candidates = s.byName[nameToKey(&cert.Issuer)] - } - - for _, c := range candidates { - if cert.CheckSignatureFrom(c) == nil { - return c - } - } - - return nil -} - -// AddCert adds a certificate to the set -func (s *CASet) AddCert(cert *x509.Certificate) { - if len(cert.SubjectKeyId) > 0 { - keyId := string(cert.SubjectKeyId) - s.bySubjectKeyId[keyId] = append(s.bySubjectKeyId[keyId], cert) - } - name := nameToKey(&cert.Subject) - s.byName[name] = append(s.byName[name], cert) -} - -// SetFromPEM attempts to parse a series of PEM encoded root certificates. It -// appends any certificates found to s and returns true if any certificates -// were successfully parsed. On many Linux systems, /etc/ssl/cert.pem will -// contains the system wide set of root CAs in a format suitable for this -// function. -func (s *CASet) SetFromPEM(pemCerts []byte) (ok bool) { - for len(pemCerts) > 0 { - var block *pem.Block - block, pemCerts = pem.Decode(pemCerts) - if block == nil { - break - } - if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { - continue - } - - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - continue - } - - s.AddCert(cert) - ok = true - } - - return -} diff --git a/libgo/go/crypto/tls/common.go b/libgo/go/crypto/tls/common.go index c7792343942..0b26aae84d1 100644 --- a/libgo/go/crypto/tls/common.go +++ b/libgo/go/crypto/tls/common.go @@ -100,6 +100,8 @@ type ConnectionState struct { // the certificate chain that was presented by the other side PeerCertificates []*x509.Certificate + // the verified certificate chains built from PeerCertificates. + VerifiedChains [][]*x509.Certificate } // A Config structure is used to configure a TLS client or server. After one @@ -122,7 +124,7 @@ type Config struct { // RootCAs defines the set of root certificate authorities // that clients use when verifying server certificates. // If RootCAs is nil, TLS uses the host's root CA set. - RootCAs *CASet + RootCAs *x509.CertPool // NextProtos is a list of supported, application level protocols. NextProtos []string @@ -158,7 +160,7 @@ func (c *Config) time() int64 { return t() } -func (c *Config) rootCAs() *CASet { +func (c *Config) rootCAs() *x509.CertPool { s := c.RootCAs if s == nil { s = defaultRoots() @@ -178,6 +180,9 @@ func (c *Config) cipherSuites() []uint16 { type Certificate struct { Certificate [][]byte PrivateKey *rsa.PrivateKey + // OCSPStaple contains an optional OCSP response which will be served + // to clients that request it. + OCSPStaple []byte } // A TLS record. @@ -221,7 +226,7 @@ var certFiles = []string{ var once sync.Once -func defaultRoots() *CASet { +func defaultRoots() *x509.CertPool { once.Do(initDefaults) return varDefaultRoots } @@ -236,14 +241,14 @@ func initDefaults() { initDefaultCipherSuites() } -var varDefaultRoots *CASet +var varDefaultRoots *x509.CertPool func initDefaultRoots() { - roots := NewCASet() + roots := x509.NewCertPool() for _, file := range certFiles { data, err := ioutil.ReadFile(file) if err == nil { - roots.SetFromPEM(data) + roots.AppendCertsFromPEM(data) break } } @@ -255,7 +260,7 @@ var varDefaultCipherSuites []uint16 func initDefaultCipherSuites() { varDefaultCipherSuites = make([]uint16, len(cipherSuites)) i := 0 - for id, _ := range cipherSuites { + for id := range cipherSuites { varDefaultCipherSuites[i] = id i++ } diff --git a/libgo/go/crypto/tls/conn.go b/libgo/go/crypto/tls/conn.go index b94e235c814..48d3f725b49 100644 --- a/libgo/go/crypto/tls/conn.go +++ b/libgo/go/crypto/tls/conn.go @@ -34,6 +34,9 @@ type Conn struct { cipherSuite uint16 ocspResponse []byte // stapled OCSP response peerCertificates []*x509.Certificate + // verifedChains contains the certificate chains that we built, as + // opposed to the ones presented by the server. + verifiedChains [][]*x509.Certificate clientProtocol string clientProtocolFallback bool @@ -765,6 +768,7 @@ func (c *Conn) ConnectionState() ConnectionState { state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback state.CipherSuite = c.cipherSuite state.PeerCertificates = c.peerCertificates + state.VerifiedChains = c.verifiedChains } return state diff --git a/libgo/go/crypto/tls/handshake_client.go b/libgo/go/crypto/tls/handshake_client.go index 540b25c8753..c758c96d4ef 100644 --- a/libgo/go/crypto/tls/handshake_client.go +++ b/libgo/go/crypto/tls/handshake_client.go @@ -88,7 +88,6 @@ func (c *Conn) clientHandshake() os.Error { finishedHash.Write(certMsg.marshal()) certs := make([]*x509.Certificate, len(certMsg.certificates)) - chain := NewCASet() for i, asn1Data := range certMsg.certificates { cert, err := x509.ParseCertificate(asn1Data) if err != nil { @@ -96,47 +95,29 @@ func (c *Conn) clientHandshake() os.Error { return os.ErrorString("failed to parse certificate from server: " + err.String()) } certs[i] = cert - chain.AddCert(cert) } // If we don't have a root CA set configured then anything is accepted. // TODO(rsc): Find certificates for OS X 10.6. - for cur := certs[0]; c.config.RootCAs != nil; { - parent := c.config.RootCAs.FindVerifiedParent(cur) - if parent != nil { - break + if c.config.RootCAs != nil { + opts := x509.VerifyOptions{ + Roots: c.config.RootCAs, + CurrentTime: c.config.time(), + DNSName: c.config.ServerName, + Intermediates: x509.NewCertPool(), } - parent = chain.FindVerifiedParent(cur) - if parent == nil { - c.sendAlert(alertBadCertificate) - return os.ErrorString("could not find root certificate for chain") + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) } - - if !parent.BasicConstraintsValid || !parent.IsCA { + c.verifiedChains, err = certs[0].Verify(opts) + if err != nil { c.sendAlert(alertBadCertificate) - return os.ErrorString("intermediate certificate does not have CA bit set") + return err } - // KeyUsage status flags are ignored. From Engineering - // Security, Peter Gutmann: A European government CA marked its - // signing certificates as being valid for encryption only, but - // no-one noticed. Another European CA marked its signature - // keys as not being valid for signatures. A different CA - // marked its own trusted root certificate as being invalid for - // certificate signing. Another national CA distributed a - // certificate to be used to encrypt data for the country’s tax - // authority that was marked as only being usable for digital - // signatures but not for encryption. Yet another CA reversed - // the order of the bit flags in the keyUsage due to confusion - // over encoding endianness, essentially setting a random - // keyUsage in certificates that it issued. Another CA created - // a self-invalidating certificate by adding a certificate - // policy statement stipulating that the certificate had to be - // used strictly as specified in the keyUsage, and a keyUsage - // containing a flag indicating that the RSA encryption key - // could only be used for Diffie-Hellman key agreement. - - cur = parent } if _, ok := certs[0].PublicKey.(*rsa.PublicKey); !ok { @@ -145,7 +126,7 @@ func (c *Conn) clientHandshake() os.Error { c.peerCertificates = certs - if serverHello.certStatus { + if serverHello.ocspStapling { msg, err = c.readHandshake() if err != nil { return err diff --git a/libgo/go/crypto/tls/handshake_messages.go b/libgo/go/crypto/tls/handshake_messages.go index e5e8562713d..6645adce4f2 100644 --- a/libgo/go/crypto/tls/handshake_messages.go +++ b/libgo/go/crypto/tls/handshake_messages.go @@ -306,7 +306,7 @@ type serverHelloMsg struct { compressionMethod uint8 nextProtoNeg bool nextProtos []string - certStatus bool + ocspStapling bool } func (m *serverHelloMsg) marshal() []byte { @@ -327,7 +327,7 @@ func (m *serverHelloMsg) marshal() []byte { nextProtoLen += len(m.nextProtos) extensionsLength += nextProtoLen } - if m.certStatus { + if m.ocspStapling { numExtensions++ } if numExtensions > 0 { @@ -373,7 +373,7 @@ func (m *serverHelloMsg) marshal() []byte { z = z[1+l:] } } - if m.certStatus { + if m.ocspStapling { z[0] = byte(extensionStatusRequest >> 8) z[1] = byte(extensionStatusRequest) z = z[4:] @@ -406,7 +406,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { m.nextProtoNeg = false m.nextProtos = nil - m.certStatus = false + m.ocspStapling = false if len(data) == 0 { // ServerHello is optionally followed by extension data @@ -450,7 +450,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { if length > 0 { return false } - m.certStatus = true + m.ocspStapling = true } data = data[length:] } diff --git a/libgo/go/crypto/tls/handshake_messages_test.go b/libgo/go/crypto/tls/handshake_messages_test.go index 0b93b89f1ad..23f729dd94b 100644 --- a/libgo/go/crypto/tls/handshake_messages_test.go +++ b/libgo/go/crypto/tls/handshake_messages_test.go @@ -32,7 +32,7 @@ type testMessage interface { func TestMarshalUnmarshal(t *testing.T) { rand := rand.New(rand.NewSource(0)) for i, iface := range tests { - ty := reflect.NewValue(iface).Type() + ty := reflect.ValueOf(iface).Type() n := 100 if testing.Short() { @@ -121,11 +121,11 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { m.ocspStapling = rand.Intn(10) > 5 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) m.supportedCurves = make([]uint16, rand.Intn(5)+1) - for i, _ := range m.supportedCurves { + for i := range m.supportedCurves { m.supportedCurves[i] = uint16(rand.Intn(30000)) } - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { @@ -146,7 +146,7 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { } } - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { @@ -156,7 +156,7 @@ func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { for i := 0; i < numCerts; i++ { m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) } - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { @@ -167,13 +167,13 @@ func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value for i := 0; i < numCAs; i++ { m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand) } - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &certificateVerifyMsg{} m.signature = randomBytes(rand.Intn(15)+1, rand) - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { @@ -184,23 +184,23 @@ func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { } else { m.statusType = 42 } - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &clientKeyExchangeMsg{} m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &finishedMsg{} m.verifyData = randomBytes(12, rand) - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &nextProtoMsg{} m.proto = randomString(rand.Intn(255), rand) - return reflect.NewValue(m) + return reflect.ValueOf(m) } diff --git a/libgo/go/crypto/tls/handshake_server.go b/libgo/go/crypto/tls/handshake_server.go index 809c8c15e5d..37c8d154ac4 100644 --- a/libgo/go/crypto/tls/handshake_server.go +++ b/libgo/go/crypto/tls/handshake_server.go @@ -103,6 +103,9 @@ FindCipherSuite: hello.nextProtoNeg = true hello.nextProtos = config.NextProtos } + if clientHello.ocspStapling && len(config.Certificates[0].OCSPStaple) > 0 { + hello.ocspStapling = true + } finishedHash.Write(hello.marshal()) c.writeRecord(recordTypeHandshake, hello.marshal()) @@ -116,6 +119,14 @@ FindCipherSuite: finishedHash.Write(certMsg.marshal()) c.writeRecord(recordTypeHandshake, certMsg.marshal()) + if hello.ocspStapling { + certStatus := new(certificateStatusMsg) + certStatus.statusType = statusTypeOCSP + certStatus.response = config.Certificates[0].OCSPStaple + finishedHash.Write(certStatus.marshal()) + c.writeRecord(recordTypeHandshake, certStatus.marshal()) + } + keyAgreement := suite.ka() skx, err := keyAgreement.generateServerKeyExchange(config, clientHello, hello) diff --git a/libgo/go/crypto/tls/handshake_server_test.go b/libgo/go/crypto/tls/handshake_server_test.go index 6beb6a9f62b..5a1e754dcf5 100644 --- a/libgo/go/crypto/tls/handshake_server_test.go +++ b/libgo/go/crypto/tls/handshake_server_test.go @@ -188,8 +188,10 @@ var testPrivateKey = &rsa.PrivateKey{ E: 65537, }, D: bigFromString("29354450337804273969007277378287027274721892607543397931919078829901848876371746653677097639302788129485893852488285045793268732234230875671682624082413996177431586734171663258657462237320300610850244186316880055243099640544518318093544057213190320837094958164973959123058337475052510833916491060913053867729"), - P: bigFromString("11969277782311800166562047708379380720136961987713178380670422671426759650127150688426177829077494755200794297055316163155755835813760102405344560929062149"), - Q: bigFromString("10998999429884441391899182616418192492905073053684657075974935218461686523870125521822756579792315215543092255516093840728890783887287417039645833477273829"), + Primes: []*big.Int{ + bigFromString("11969277782311800166562047708379380720136961987713178380670422671426759650127150688426177829077494755200794297055316163155755835813760102405344560929062149"), + bigFromString("10998999429884441391899182616418192492905073053684657075974935218461686523870125521822756579792315215543092255516093840728890783887287417039645833477273829"), + }, } // Script of interaction with gnutls implementation. diff --git a/libgo/go/crypto/tls/tls.go b/libgo/go/crypto/tls/tls.go index 7de44bbd244..7d0bb9f34b8 100644 --- a/libgo/go/crypto/tls/tls.go +++ b/libgo/go/crypto/tls/tls.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package partially implements the TLS 1.1 protocol, as specified in RFC 4346. +// Package tls partially implements the TLS 1.1 protocol, as specified in RFC +// 4346. package tls import ( diff --git a/libgo/go/crypto/twofish/twofish.go b/libgo/go/crypto/twofish/twofish.go index 62253e79788..9303f03ffd8 100644 --- a/libgo/go/crypto/twofish/twofish.go +++ b/libgo/go/crypto/twofish/twofish.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements Bruce Schneier's Twofish encryption algorithm. +// Package twofish implements Bruce Schneier's Twofish encryption algorithm. package twofish // Twofish is defined in http://www.schneier.com/paper-twofish-paper.pdf [TWOFISH] diff --git a/libgo/go/crypto/x509/cert_pool.go b/libgo/go/crypto/x509/cert_pool.go new file mode 100644 index 00000000000..c295fd97e8d --- /dev/null +++ b/libgo/go/crypto/x509/cert_pool.go @@ -0,0 +1,105 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package x509 + +import ( + "encoding/pem" + "strings" +) + +// Roots is a set of certificates. +type CertPool struct { + bySubjectKeyId map[string][]int + byName map[string][]int + certs []*Certificate +} + +// NewCertPool returns a new, empty CertPool. +func NewCertPool() *CertPool { + return &CertPool{ + make(map[string][]int), + make(map[string][]int), + nil, + } +} + +func nameToKey(name *Name) string { + return strings.Join(name.Country, ",") + "/" + strings.Join(name.Organization, ",") + "/" + strings.Join(name.OrganizationalUnit, ",") + "/" + name.CommonName +} + +// findVerifiedParents attempts to find certificates in s which have signed the +// given certificate. If no such certificate can be found or the signature +// doesn't match, it returns nil. +func (s *CertPool) findVerifiedParents(cert *Certificate) (parents []int) { + var candidates []int + + if len(cert.AuthorityKeyId) > 0 { + candidates = s.bySubjectKeyId[string(cert.AuthorityKeyId)] + } + if len(candidates) == 0 { + candidates = s.byName[nameToKey(&cert.Issuer)] + } + + for _, c := range candidates { + if cert.CheckSignatureFrom(s.certs[c]) == nil { + parents = append(parents, c) + } + } + + return +} + +// AddCert adds a certificate to a pool. +func (s *CertPool) AddCert(cert *Certificate) { + if cert == nil { + panic("adding nil Certificate to CertPool") + } + + // Check that the certificate isn't being added twice. + for _, c := range s.certs { + if c.Equal(cert) { + return + } + } + + n := len(s.certs) + s.certs = append(s.certs, cert) + + if len(cert.SubjectKeyId) > 0 { + keyId := string(cert.SubjectKeyId) + s.bySubjectKeyId[keyId] = append(s.bySubjectKeyId[keyId], n) + } + name := nameToKey(&cert.Subject) + s.byName[name] = append(s.byName[name], n) +} + +// AppendCertsFromPEM attempts to parse a series of PEM encoded root +// certificates. It appends any certificates found to s and returns true if any +// certificates were successfully parsed. +// +// On many Linux systems, /etc/ssl/cert.pem will contains the system wide set +// of root CAs in a format suitable for this function. +func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) { + for len(pemCerts) > 0 { + var block *pem.Block + block, pemCerts = pem.Decode(pemCerts) + if block == nil { + break + } + if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { + continue + } + + cert, err := ParseCertificate(block.Bytes) + if err != nil { + continue + } + + s.AddCert(cert) + ok = true + } + + return +} diff --git a/libgo/go/crypto/x509/verify.go b/libgo/go/crypto/x509/verify.go new file mode 100644 index 00000000000..9145880a237 --- /dev/null +++ b/libgo/go/crypto/x509/verify.go @@ -0,0 +1,239 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package x509 + +import ( + "os" + "strings" + "time" +) + +type InvalidReason int + +const ( + // NotAuthorizedToSign results when a certificate is signed by another + // which isn't marked as a CA certificate. + NotAuthorizedToSign InvalidReason = iota + // Expired results when a certificate has expired, based on the time + // given in the VerifyOptions. + Expired + // CANotAuthorizedForThisName results when an intermediate or root + // certificate has a name constraint which doesn't include the name + // being checked. + CANotAuthorizedForThisName +) + +// CertificateInvalidError results when an odd error occurs. Users of this +// library probably want to handle all these errors uniformly. +type CertificateInvalidError struct { + Cert *Certificate + Reason InvalidReason +} + +func (e CertificateInvalidError) String() string { + switch e.Reason { + case NotAuthorizedToSign: + return "x509: certificate is not authorized to sign other other certificates" + case Expired: + return "x509: certificate has expired or is not yet valid" + case CANotAuthorizedForThisName: + return "x509: a root or intermediate certificate is not authorized to sign in this domain" + } + return "x509: unknown error" +} + +// HostnameError results when the set of authorized names doesn't match the +// requested name. +type HostnameError struct { + Certificate *Certificate + Host string +} + +func (h HostnameError) String() string { + var valid string + c := h.Certificate + if len(c.DNSNames) > 0 { + valid = strings.Join(c.DNSNames, ", ") + } else { + valid = c.Subject.CommonName + } + return "certificate is valid for " + valid + ", not " + h.Host +} + + +// UnknownAuthorityError results when the certificate issuer is unknown +type UnknownAuthorityError struct { + cert *Certificate +} + +func (e UnknownAuthorityError) String() string { + return "x509: certificate signed by unknown authority" +} + +// VerifyOptions contains parameters for Certificate.Verify. It's a structure +// because other PKIX verification APIs have ended up needing many options. +type VerifyOptions struct { + DNSName string + Intermediates *CertPool + Roots *CertPool + CurrentTime int64 // if 0, the current system time is used. +} + +const ( + leafCertificate = iota + intermediateCertificate + rootCertificate +) + +// isValid performs validity checks on the c. +func (c *Certificate) isValid(certType int, opts *VerifyOptions) os.Error { + if opts.CurrentTime < c.NotBefore.Seconds() || + opts.CurrentTime > c.NotAfter.Seconds() { + return CertificateInvalidError{c, Expired} + } + + if len(c.PermittedDNSDomains) > 0 { + for _, domain := range c.PermittedDNSDomains { + if opts.DNSName == domain || + (strings.HasSuffix(opts.DNSName, domain) && + len(opts.DNSName) >= 1+len(domain) && + opts.DNSName[len(opts.DNSName)-len(domain)-1] == '.') { + continue + } + + return CertificateInvalidError{c, CANotAuthorizedForThisName} + } + } + + // KeyUsage status flags are ignored. From Engineering Security, Peter + // Gutmann: A European government CA marked its signing certificates as + // being valid for encryption only, but no-one noticed. Another + // European CA marked its signature keys as not being valid for + // signatures. A different CA marked its own trusted root certificate + // as being invalid for certificate signing. Another national CA + // distributed a certificate to be used to encrypt data for the + // country’s tax authority that was marked as only being usable for + // digital signatures but not for encryption. Yet another CA reversed + // the order of the bit flags in the keyUsage due to confusion over + // encoding endianness, essentially setting a random keyUsage in + // certificates that it issued. Another CA created a self-invalidating + // certificate by adding a certificate policy statement stipulating + // that the certificate had to be used strictly as specified in the + // keyUsage, and a keyUsage containing a flag indicating that the RSA + // encryption key could only be used for Diffie-Hellman key agreement. + + if certType == intermediateCertificate && (!c.BasicConstraintsValid || !c.IsCA) { + return CertificateInvalidError{c, NotAuthorizedToSign} + } + + return nil +} + +// Verify attempts to verify c by building one or more chains from c to a +// certificate in opts.roots, using certificates in opts.Intermediates if +// needed. If successful, it returns one or chains where the first element of +// the chain is c and the last element is from opts.Roots. +// +// WARNING: this doesn't do any revocation checking. +func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err os.Error) { + if opts.CurrentTime == 0 { + opts.CurrentTime = time.Seconds() + } + err = c.isValid(leafCertificate, &opts) + if err != nil { + return + } + if len(opts.DNSName) > 0 { + err = c.VerifyHostname(opts.DNSName) + if err != nil { + return + } + } + return c.buildChains(make(map[int][][]*Certificate), []*Certificate{c}, &opts) +} + +func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate { + n := make([]*Certificate, len(chain)+1) + copy(n, chain) + n[len(chain)] = cert + return n +} + +func (c *Certificate) buildChains(cache map[int][][]*Certificate, currentChain []*Certificate, opts *VerifyOptions) (chains [][]*Certificate, err os.Error) { + for _, rootNum := range opts.Roots.findVerifiedParents(c) { + root := opts.Roots.certs[rootNum] + err = root.isValid(rootCertificate, opts) + if err != nil { + continue + } + chains = append(chains, appendToFreshChain(currentChain, root)) + } + + for _, intermediateNum := range opts.Intermediates.findVerifiedParents(c) { + intermediate := opts.Intermediates.certs[intermediateNum] + err = intermediate.isValid(intermediateCertificate, opts) + if err != nil { + continue + } + var childChains [][]*Certificate + childChains, ok := cache[intermediateNum] + if !ok { + childChains, err = intermediate.buildChains(cache, appendToFreshChain(currentChain, intermediate), opts) + cache[intermediateNum] = childChains + } + chains = append(chains, childChains...) + } + + if len(chains) > 0 { + err = nil + } + + if len(chains) == 0 && err == nil { + err = UnknownAuthorityError{c} + } + + return +} + +func matchHostnames(pattern, host string) bool { + if len(pattern) == 0 || len(host) == 0 { + return false + } + + patternParts := strings.Split(pattern, ".", -1) + hostParts := strings.Split(host, ".", -1) + + if len(patternParts) != len(hostParts) { + return false + } + + for i, patternPart := range patternParts { + if patternPart == "*" { + continue + } + if patternPart != hostParts[i] { + return false + } + } + + return true +} + +// VerifyHostname returns nil if c is a valid certificate for the named host. +// Otherwise it returns an os.Error describing the mismatch. +func (c *Certificate) VerifyHostname(h string) os.Error { + if len(c.DNSNames) > 0 { + for _, match := range c.DNSNames { + if matchHostnames(match, h) { + return nil + } + } + // If Subject Alt Name is given, we ignore the common name. + } else if matchHostnames(c.Subject.CommonName, h) { + return nil + } + + return HostnameError{c, h} +} diff --git a/libgo/go/crypto/x509/verify_test.go b/libgo/go/crypto/x509/verify_test.go new file mode 100644 index 00000000000..6a103dcfba7 --- /dev/null +++ b/libgo/go/crypto/x509/verify_test.go @@ -0,0 +1,390 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package x509 + +import ( + "encoding/pem" + "os" + "strings" + "testing" +) + +type verifyTest struct { + leaf string + intermediates []string + roots []string + currentTime int64 + dnsName string + + errorCallback func(*testing.T, int, os.Error) bool + expectedChains [][]string +} + +var verifyTests = []verifyTest{ + { + leaf: googleLeaf, + intermediates: []string{thawteIntermediate}, + roots: []string{verisignRoot}, + currentTime: 1302726541, + dnsName: "www.google.com", + + expectedChains: [][]string{ + []string{"Google", "Thawte", "VeriSign"}, + }, + }, + { + leaf: googleLeaf, + intermediates: []string{thawteIntermediate}, + roots: []string{verisignRoot}, + currentTime: 1302726541, + dnsName: "www.example.com", + + errorCallback: expectHostnameError, + }, + { + leaf: googleLeaf, + intermediates: []string{thawteIntermediate}, + roots: []string{verisignRoot}, + currentTime: 1, + dnsName: "www.example.com", + + errorCallback: expectExpired, + }, + { + leaf: googleLeaf, + roots: []string{verisignRoot}, + currentTime: 1302726541, + dnsName: "www.google.com", + + errorCallback: expectAuthorityUnknown, + }, + { + leaf: googleLeaf, + intermediates: []string{verisignRoot, thawteIntermediate}, + roots: []string{verisignRoot}, + currentTime: 1302726541, + dnsName: "www.google.com", + + expectedChains: [][]string{ + []string{"Google", "Thawte", "VeriSign"}, + }, + }, + { + leaf: googleLeaf, + intermediates: []string{verisignRoot, thawteIntermediate}, + roots: []string{verisignRoot}, + currentTime: 1302726541, + + expectedChains: [][]string{ + []string{"Google", "Thawte", "VeriSign"}, + }, + }, + { + leaf: dnssecExpLeaf, + intermediates: []string{startComIntermediate}, + roots: []string{startComRoot}, + currentTime: 1302726541, + + expectedChains: [][]string{ + []string{"dnssec-exp", "StartCom Class 1", "StartCom Certification Authority"}, + }, + }, +} + +func expectHostnameError(t *testing.T, i int, err os.Error) (ok bool) { + if _, ok := err.(HostnameError); !ok { + t.Errorf("#%d: error was not a HostnameError: %s", i, err) + return false + } + return true +} + +func expectExpired(t *testing.T, i int, err os.Error) (ok bool) { + if inval, ok := err.(CertificateInvalidError); !ok || inval.Reason != Expired { + t.Errorf("#%d: error was not Expired: %s", i, err) + return false + } + return true +} + +func expectAuthorityUnknown(t *testing.T, i int, err os.Error) (ok bool) { + if _, ok := err.(UnknownAuthorityError); !ok { + t.Errorf("#%d: error was not UnknownAuthorityError: %s", i, err) + return false + } + return true +} + +func certificateFromPEM(pemBytes string) (*Certificate, os.Error) { + block, _ := pem.Decode([]byte(pemBytes)) + if block == nil { + return nil, os.ErrorString("failed to decode PEM") + } + return ParseCertificate(block.Bytes) +} + +func TestVerify(t *testing.T) { + for i, test := range verifyTests { + opts := VerifyOptions{ + Roots: NewCertPool(), + Intermediates: NewCertPool(), + DNSName: test.dnsName, + CurrentTime: test.currentTime, + } + + for j, root := range test.roots { + ok := opts.Roots.AppendCertsFromPEM([]byte(root)) + if !ok { + t.Errorf("#%d: failed to parse root #%d", i, j) + return + } + } + + for j, intermediate := range test.intermediates { + ok := opts.Intermediates.AppendCertsFromPEM([]byte(intermediate)) + if !ok { + t.Errorf("#%d: failed to parse intermediate #%d", i, j) + return + } + } + + leaf, err := certificateFromPEM(test.leaf) + if err != nil { + t.Errorf("#%d: failed to parse leaf: %s", i, err) + return + } + + chains, err := leaf.Verify(opts) + + if test.errorCallback == nil && err != nil { + t.Errorf("#%d: unexpected error: %s", i, err) + } + if test.errorCallback != nil { + if !test.errorCallback(t, i, err) { + return + } + } + + if len(chains) != len(test.expectedChains) { + t.Errorf("#%d: wanted %d chains, got %d", i, len(test.expectedChains), len(chains)) + } + + // We check that each returned chain matches a chain from + // expectedChains but an entry in expectedChains can't match + // two chains. + seenChains := make([]bool, len(chains)) + NextOutputChain: + for _, chain := range chains { + TryNextExpected: + for j, expectedChain := range test.expectedChains { + if seenChains[j] { + continue + } + if len(chain) != len(expectedChain) { + continue + } + for k, cert := range chain { + if strings.Index(nameToKey(&cert.Subject), expectedChain[k]) == -1 { + continue TryNextExpected + } + } + // we matched + seenChains[j] = true + continue NextOutputChain + } + t.Errorf("#%d: No expected chain matched %s", i, chainToDebugString(chain)) + } + } +} + +func chainToDebugString(chain []*Certificate) string { + var chainStr string + for _, cert := range chain { + if len(chainStr) > 0 { + chainStr += " -> " + } + chainStr += nameToKey(&cert.Subject) + } + return chainStr +} + +const verisignRoot = `-----BEGIN CERTIFICATE----- +MIICPDCCAaUCEHC65B0Q2Sk0tjjKewPMur8wDQYJKoZIhvcNAQECBQAwXzELMAkG +A1UEBhMCVVMxFzAVBgNVBAoTDlZlcmlTaWduLCBJbmMuMTcwNQYDVQQLEy5DbGFz +cyAzIFB1YmxpYyBQcmltYXJ5IENlcnRpZmljYXRpb24gQXV0aG9yaXR5MB4XDTk2 +MDEyOTAwMDAwMFoXDTI4MDgwMTIzNTk1OVowXzELMAkGA1UEBhMCVVMxFzAVBgNV +BAoTDlZlcmlTaWduLCBJbmMuMTcwNQYDVQQLEy5DbGFzcyAzIFB1YmxpYyBQcmlt +YXJ5IENlcnRpZmljYXRpb24gQXV0aG9yaXR5MIGfMA0GCSqGSIb3DQEBAQUAA4GN +ADCBiQKBgQDJXFme8huKARS0EN8EQNvjV69qRUCPhAwL0TPZ2RHP7gJYHyX3KqhE +BarsAx94f56TuZoAqiN91qyFomNFx3InzPRMxnVx0jnvT0Lwdd8KkMaOIG+YD/is +I19wKTakyYbnsZogy1Olhec9vn2a/iRFM9x2Fe0PonFkTGUugWhFpwIDAQABMA0G +CSqGSIb3DQEBAgUAA4GBALtMEivPLCYATxQT3ab7/AoRhIzzKBxnki98tsX63/Do +lbwdj2wsqFHMc9ikwFPwTtYmwHYBV4GSXiHx0bH/59AhWM1pF+NEHJwZRDmJXNyc +AA9WjQKZ7aKQRUzkuxCkPfAyAw7xzvjoyVGM5mKf5p/AfbdynMk2OmufTqj/ZA1k +-----END CERTIFICATE----- +` + +const thawteIntermediate = `-----BEGIN CERTIFICATE----- +MIIDIzCCAoygAwIBAgIEMAAAAjANBgkqhkiG9w0BAQUFADBfMQswCQYDVQQGEwJV +UzEXMBUGA1UEChMOVmVyaVNpZ24sIEluYy4xNzA1BgNVBAsTLkNsYXNzIDMgUHVi +bGljIFByaW1hcnkgQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkwHhcNMDQwNTEzMDAw +MDAwWhcNMTQwNTEyMjM1OTU5WjBMMQswCQYDVQQGEwJaQTElMCMGA1UEChMcVGhh +d3RlIENvbnN1bHRpbmcgKFB0eSkgTHRkLjEWMBQGA1UEAxMNVGhhd3RlIFNHQyBD +QTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1NNn0I0Vf67NMf59HZGhPwtx +PKzMyGT7Y/wySweUvW+Aui/hBJPAM/wJMyPpC3QrccQDxtLN4i/1CWPN/0ilAL/g +5/OIty0y3pg25gqtAHvEZEo7hHUD8nCSfQ5i9SGraTaEMXWQ+L/HbIgbBpV8yeWo +3nWhLHpo39XKHIdYYBkCAwEAAaOB/jCB+zASBgNVHRMBAf8ECDAGAQH/AgEAMAsG +A1UdDwQEAwIBBjARBglghkgBhvhCAQEEBAMCAQYwKAYDVR0RBCEwH6QdMBsxGTAX +BgNVBAMTEFByaXZhdGVMYWJlbDMtMTUwMQYDVR0fBCowKDAmoCSgIoYgaHR0cDov +L2NybC52ZXJpc2lnbi5jb20vcGNhMy5jcmwwMgYIKwYBBQUHAQEEJjAkMCIGCCsG +AQUFBzABhhZodHRwOi8vb2NzcC50aGF3dGUuY29tMDQGA1UdJQQtMCsGCCsGAQUF +BwMBBggrBgEFBQcDAgYJYIZIAYb4QgQBBgpghkgBhvhFAQgBMA0GCSqGSIb3DQEB +BQUAA4GBAFWsY+reod3SkF+fC852vhNRj5PZBSvIG3dLrWlQoe7e3P3bB+noOZTc +q3J5Lwa/q4FwxKjt6lM07e8eU9kGx1Yr0Vz00YqOtCuxN5BICEIlxT6Ky3/rbwTR +bcV0oveifHtgPHfNDs5IAn8BL7abN+AqKjbc1YXWrOU/VG+WHgWv +-----END CERTIFICATE----- +` + +const googleLeaf = `-----BEGIN CERTIFICATE----- +MIIDITCCAoqgAwIBAgIQL9+89q6RUm0PmqPfQDQ+mjANBgkqhkiG9w0BAQUFADBM +MQswCQYDVQQGEwJaQTElMCMGA1UEChMcVGhhd3RlIENvbnN1bHRpbmcgKFB0eSkg +THRkLjEWMBQGA1UEAxMNVGhhd3RlIFNHQyBDQTAeFw0wOTEyMTgwMDAwMDBaFw0x +MTEyMTgyMzU5NTlaMGgxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlh +MRYwFAYDVQQHFA1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKFApHb29nbGUgSW5jMRcw +FQYDVQQDFA53d3cuZ29vZ2xlLmNvbTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkC +gYEA6PmGD5D6htffvXImttdEAoN4c9kCKO+IRTn7EOh8rqk41XXGOOsKFQebg+jN +gtXj9xVoRaELGYW84u+E593y17iYwqG7tcFR39SDAqc9BkJb4SLD3muFXxzW2k6L +05vuuWciKh0R73mkszeK9P4Y/bz5RiNQl/Os/CRGK1w7t0UCAwEAAaOB5zCB5DAM +BgNVHRMBAf8EAjAAMDYGA1UdHwQvMC0wK6ApoCeGJWh0dHA6Ly9jcmwudGhhd3Rl +LmNvbS9UaGF3dGVTR0NDQS5jcmwwKAYDVR0lBCEwHwYIKwYBBQUHAwEGCCsGAQUF +BwMCBglghkgBhvhCBAEwcgYIKwYBBQUHAQEEZjBkMCIGCCsGAQUFBzABhhZodHRw +Oi8vb2NzcC50aGF3dGUuY29tMD4GCCsGAQUFBzAChjJodHRwOi8vd3d3LnRoYXd0 +ZS5jb20vcmVwb3NpdG9yeS9UaGF3dGVfU0dDX0NBLmNydDANBgkqhkiG9w0BAQUF +AAOBgQCfQ89bxFApsb/isJr/aiEdLRLDLE5a+RLizrmCUi3nHX4adpaQedEkUjh5 +u2ONgJd8IyAPkU0Wueru9G2Jysa9zCRo1kNbzipYvzwY4OA8Ys+WAi0oR1A04Se6 +z5nRUP8pJcA2NhUzUnC+MY+f6H/nEQyNv4SgQhqAibAxWEEHXw== +-----END CERTIFICATE-----` + +const dnssecExpLeaf = `-----BEGIN CERTIFICATE----- +MIIGzTCCBbWgAwIBAgIDAdD6MA0GCSqGSIb3DQEBBQUAMIGMMQswCQYDVQQGEwJJ +TDEWMBQGA1UEChMNU3RhcnRDb20gTHRkLjErMCkGA1UECxMiU2VjdXJlIERpZ2l0 +YWwgQ2VydGlmaWNhdGUgU2lnbmluZzE4MDYGA1UEAxMvU3RhcnRDb20gQ2xhc3Mg +MSBQcmltYXJ5IEludGVybWVkaWF0ZSBTZXJ2ZXIgQ0EwHhcNMTAwNzA0MTQ1MjQ1 +WhcNMTEwNzA1MTA1NzA0WjCBwTEgMB4GA1UEDRMXMjIxMTM3LWxpOWE5dHhJRzZM +NnNyVFMxCzAJBgNVBAYTAlVTMR4wHAYDVQQKExVQZXJzb25hIE5vdCBWYWxpZGF0 +ZWQxKTAnBgNVBAsTIFN0YXJ0Q29tIEZyZWUgQ2VydGlmaWNhdGUgTWVtYmVyMRsw +GQYDVQQDExJ3d3cuZG5zc2VjLWV4cC5vcmcxKDAmBgkqhkiG9w0BCQEWGWhvc3Rt +YXN0ZXJAZG5zc2VjLWV4cC5vcmcwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK +AoIBAQDEdF/22vaxrPbqpgVYMWi+alfpzBctpbfLBdPGuqOazJdCT0NbWcK8/+B4 +X6OlSOURNIlwLzhkmwVsWdVv6dVSaN7d4yI/fJkvgfDB9+au+iBJb6Pcz8ULBfe6 +D8HVvqKdORp6INzHz71z0sghxrQ0EAEkoWAZLh+kcn2ZHdcmZaBNUfjmGbyU6PRt +RjdqoP+owIaC1aktBN7zl4uO7cRjlYFdusINrh2kPP02KAx2W84xjxX1uyj6oS6e +7eBfvcwe8czW/N1rbE0CoR7h9+HnIrjnVG9RhBiZEiw3mUmF++Up26+4KTdRKbu3 ++BL4yMpfd66z0+zzqu+HkvyLpFn5AgMBAAGjggL/MIIC+zAJBgNVHRMEAjAAMAsG +A1UdDwQEAwIDqDATBgNVHSUEDDAKBggrBgEFBQcDATAdBgNVHQ4EFgQUy04I5guM +drzfh2JQaXhgV86+4jUwHwYDVR0jBBgwFoAU60I00Jiwq5/0G2sI98xkLu8OLEUw +LQYDVR0RBCYwJIISd3d3LmRuc3NlYy1leHAub3Jngg5kbnNzZWMtZXhwLm9yZzCC +AUIGA1UdIASCATkwggE1MIIBMQYLKwYBBAGBtTcBAgIwggEgMC4GCCsGAQUFBwIB +FiJodHRwOi8vd3d3LnN0YXJ0c3NsLmNvbS9wb2xpY3kucGRmMDQGCCsGAQUFBwIB +FihodHRwOi8vd3d3LnN0YXJ0c3NsLmNvbS9pbnRlcm1lZGlhdGUucGRmMIG3Bggr +BgEFBQcCAjCBqjAUFg1TdGFydENvbSBMdGQuMAMCAQEagZFMaW1pdGVkIExpYWJp +bGl0eSwgc2VlIHNlY3Rpb24gKkxlZ2FsIExpbWl0YXRpb25zKiBvZiB0aGUgU3Rh +cnRDb20gQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkgUG9saWN5IGF2YWlsYWJsZSBh +dCBodHRwOi8vd3d3LnN0YXJ0c3NsLmNvbS9wb2xpY3kucGRmMGEGA1UdHwRaMFgw +KqAooCaGJGh0dHA6Ly93d3cuc3RhcnRzc2wuY29tL2NydDEtY3JsLmNybDAqoCig +JoYkaHR0cDovL2NybC5zdGFydHNzbC5jb20vY3J0MS1jcmwuY3JsMIGOBggrBgEF +BQcBAQSBgTB/MDkGCCsGAQUFBzABhi1odHRwOi8vb2NzcC5zdGFydHNzbC5jb20v +c3ViL2NsYXNzMS9zZXJ2ZXIvY2EwQgYIKwYBBQUHMAKGNmh0dHA6Ly93d3cuc3Rh +cnRzc2wuY29tL2NlcnRzL3N1Yi5jbGFzczEuc2VydmVyLmNhLmNydDAjBgNVHRIE +HDAahhhodHRwOi8vd3d3LnN0YXJ0c3NsLmNvbS8wDQYJKoZIhvcNAQEFBQADggEB +ACXj6SB59KRJPenn6gUdGEqcta97U769SATyiQ87i9er64qLwvIGLMa3o2Rcgl2Y +kghUeyLdN/EXyFBYA8L8uvZREPoc7EZukpT/ZDLXy9i2S0jkOxvF2fD/XLbcjGjM +iEYG1/6ASw0ri9C0k4oDDoJLCoeH9++yqF7SFCCMcDkJqiAGXNb4euDpa8vCCtEQ +CSS+ObZbfkreRt3cNCf5LfCXe9OsTnCfc8Cuq81c0oLaG+SmaLUQNBuToq8e9/Zm ++b+/a3RVjxmkV5OCcGVBxsXNDn54Q6wsdw0TBMcjwoEndzpLS7yWgFbbkq5ZiGpw +Qibb2+CfKuQ+WFV1GkVQmVA= +-----END CERTIFICATE-----` + +const startComIntermediate = `-----BEGIN CERTIFICATE----- +MIIGNDCCBBygAwIBAgIBGDANBgkqhkiG9w0BAQUFADB9MQswCQYDVQQGEwJJTDEW +MBQGA1UEChMNU3RhcnRDb20gTHRkLjErMCkGA1UECxMiU2VjdXJlIERpZ2l0YWwg +Q2VydGlmaWNhdGUgU2lnbmluZzEpMCcGA1UEAxMgU3RhcnRDb20gQ2VydGlmaWNh +dGlvbiBBdXRob3JpdHkwHhcNMDcxMDI0MjA1NDE3WhcNMTcxMDI0MjA1NDE3WjCB +jDELMAkGA1UEBhMCSUwxFjAUBgNVBAoTDVN0YXJ0Q29tIEx0ZC4xKzApBgNVBAsT +IlNlY3VyZSBEaWdpdGFsIENlcnRpZmljYXRlIFNpZ25pbmcxODA2BgNVBAMTL1N0 +YXJ0Q29tIENsYXNzIDEgUHJpbWFyeSBJbnRlcm1lZGlhdGUgU2VydmVyIENBMIIB +IjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtonGrO8JUngHrJJj0PREGBiE +gFYfka7hh/oyULTTRwbw5gdfcA4Q9x3AzhA2NIVaD5Ksg8asWFI/ujjo/OenJOJA +pgh2wJJuniptTT9uYSAK21ne0n1jsz5G/vohURjXzTCm7QduO3CHtPn66+6CPAVv +kvek3AowHpNz/gfK11+AnSJYUq4G2ouHI2mw5CrY6oPSvfNx23BaKA+vWjhwRRI/ +ME3NO68X5Q/LoKldSKqxYVDLNM08XMML6BDAjJvwAwNi/rJsPnIO7hxDKslIDlc5 +xDEhyBDBLIf+VJVSH1I8MRKbf+fAoKVZ1eKPPvDVqOHXcDGpxLPPr21TLwb0pwID +AQABo4IBrTCCAakwDwYDVR0TAQH/BAUwAwEB/zAOBgNVHQ8BAf8EBAMCAQYwHQYD +VR0OBBYEFOtCNNCYsKuf9BtrCPfMZC7vDixFMB8GA1UdIwQYMBaAFE4L7xqkQFul +F2mHMMo0aEPQQa7yMGYGCCsGAQUFBwEBBFowWDAnBggrBgEFBQcwAYYbaHR0cDov +L29jc3Auc3RhcnRzc2wuY29tL2NhMC0GCCsGAQUFBzAChiFodHRwOi8vd3d3LnN0 +YXJ0c3NsLmNvbS9zZnNjYS5jcnQwWwYDVR0fBFQwUjAnoCWgI4YhaHR0cDovL3d3 +dy5zdGFydHNzbC5jb20vc2ZzY2EuY3JsMCegJaAjhiFodHRwOi8vY3JsLnN0YXJ0 +c3NsLmNvbS9zZnNjYS5jcmwwgYAGA1UdIAR5MHcwdQYLKwYBBAGBtTcBAgEwZjAu +BggrBgEFBQcCARYiaHR0cDovL3d3dy5zdGFydHNzbC5jb20vcG9saWN5LnBkZjA0 +BggrBgEFBQcCARYoaHR0cDovL3d3dy5zdGFydHNzbC5jb20vaW50ZXJtZWRpYXRl +LnBkZjANBgkqhkiG9w0BAQUFAAOCAgEAIQlJPqWIbuALi0jaMU2P91ZXouHTYlfp +tVbzhUV1O+VQHwSL5qBaPucAroXQ+/8gA2TLrQLhxpFy+KNN1t7ozD+hiqLjfDen +xk+PNdb01m4Ge90h2c9W/8swIkn+iQTzheWq8ecf6HWQTd35RvdCNPdFWAwRDYSw +xtpdPvkBnufh2lWVvnQce/xNFE+sflVHfXv0pQ1JHpXo9xLBzP92piVH0PN1Nb6X +t1gW66pceG/sUzCv6gRNzKkC4/C2BBL2MLERPZBOVmTX3DxDX3M570uvh+v2/miI +RHLq0gfGabDBoYvvF0nXYbFFSF87ICHpW7LM9NfpMfULFWE7epTj69m8f5SuauNi +YpaoZHy4h/OZMn6SolK+u/hlz8nyMPyLwcKmltdfieFcNID1j0cHL7SRv7Gifl9L +WtBbnySGBVFaaQNlQ0lxxeBvlDRr9hvYqbBMflPrj0jfyjO1SPo2ShpTpjMM0InN +SRXNiTE8kMBy12VLUjWKRhFEuT2OKGWmPnmeXAhEKa2wNREuIU640ucQPl2Eg7PD +wuTSxv0JS3QJ3fGz0xk+gA2iCxnwOOfFwq/iI9th4p1cbiCJSS4jarJiwUW0n6+L +p/EiO/h94pDQehn7Skzj0n1fSoMD7SfWI55rjbRZotnvbIIp3XUZPD9MEI3vu3Un +0q6Dp6jOW6c= +-----END CERTIFICATE-----` + +const startComRoot = `-----BEGIN CERTIFICATE----- +MIIHyTCCBbGgAwIBAgIBATANBgkqhkiG9w0BAQUFADB9MQswCQYDVQQGEwJJTDEW +MBQGA1UEChMNU3RhcnRDb20gTHRkLjErMCkGA1UECxMiU2VjdXJlIERpZ2l0YWwg +Q2VydGlmaWNhdGUgU2lnbmluZzEpMCcGA1UEAxMgU3RhcnRDb20gQ2VydGlmaWNh +dGlvbiBBdXRob3JpdHkwHhcNMDYwOTE3MTk0NjM2WhcNMzYwOTE3MTk0NjM2WjB9 +MQswCQYDVQQGEwJJTDEWMBQGA1UEChMNU3RhcnRDb20gTHRkLjErMCkGA1UECxMi +U2VjdXJlIERpZ2l0YWwgQ2VydGlmaWNhdGUgU2lnbmluZzEpMCcGA1UEAxMgU3Rh +cnRDb20gQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkwggIiMA0GCSqGSIb3DQEBAQUA +A4ICDwAwggIKAoICAQDBiNsJvGxGfHiflXu1M5DycmLWwTYgIiRezul38kMKogZk +pMyONvg45iPwbm2xPN1yo4UcodM9tDMr0y+v/uqwQVlntsQGfQqedIXWeUyAN3rf +OQVSWff0G0ZDpNKFhdLDcfN1YjS6LIp/Ho/u7TTQEceWzVI9ujPW3U3eCztKS5/C +Ji/6tRYccjV3yjxd5srhJosaNnZcAdt0FCX+7bWgiA/deMotHweXMAEtcnn6RtYT +Kqi5pquDSR3l8u/d5AGOGAqPY1MWhWKpDhk6zLVmpsJrdAfkK+F2PrRt2PZE4XNi +HzvEvqBTViVsUQn3qqvKv3b9bZvzndu/PWa8DFaqr5hIlTpL36dYUNk4dalb6kMM +Av+Z6+hsTXBbKWWc3apdzK8BMewM69KN6Oqce+Zu9ydmDBpI125C4z/eIT574Q1w ++2OqqGwaVLRcJXrJosmLFqa7LH4XXgVNWG4SHQHuEhANxjJ/GP/89PrNbpHoNkm+ +Gkhpi8KWTRoSsmkXwQqQ1vp5Iki/untp+HDH+no32NgN0nZPV/+Qt+OR0t3vwmC3 +Zzrd/qqc8NSLf3Iizsafl7b4r4qgEKjZ+xjGtrVcUjyJthkqcwEKDwOzEmDyei+B +26Nu/yYwl/WL3YlXtq09s68rxbd2AvCl1iuahhQqcvbjM4xdCUsT37uMdBNSSwID +AQABo4ICUjCCAk4wDAYDVR0TBAUwAwEB/zALBgNVHQ8EBAMCAa4wHQYDVR0OBBYE +FE4L7xqkQFulF2mHMMo0aEPQQa7yMGQGA1UdHwRdMFswLKAqoCiGJmh0dHA6Ly9j +ZXJ0LnN0YXJ0Y29tLm9yZy9zZnNjYS1jcmwuY3JsMCugKaAnhiVodHRwOi8vY3Js +LnN0YXJ0Y29tLm9yZy9zZnNjYS1jcmwuY3JsMIIBXQYDVR0gBIIBVDCCAVAwggFM +BgsrBgEEAYG1NwEBATCCATswLwYIKwYBBQUHAgEWI2h0dHA6Ly9jZXJ0LnN0YXJ0 +Y29tLm9yZy9wb2xpY3kucGRmMDUGCCsGAQUFBwIBFilodHRwOi8vY2VydC5zdGFy +dGNvbS5vcmcvaW50ZXJtZWRpYXRlLnBkZjCB0AYIKwYBBQUHAgIwgcMwJxYgU3Rh +cnQgQ29tbWVyY2lhbCAoU3RhcnRDb20pIEx0ZC4wAwIBARqBl0xpbWl0ZWQgTGlh +YmlsaXR5LCByZWFkIHRoZSBzZWN0aW9uICpMZWdhbCBMaW1pdGF0aW9ucyogb2Yg +dGhlIFN0YXJ0Q29tIENlcnRpZmljYXRpb24gQXV0aG9yaXR5IFBvbGljeSBhdmFp +bGFibGUgYXQgaHR0cDovL2NlcnQuc3RhcnRjb20ub3JnL3BvbGljeS5wZGYwEQYJ +YIZIAYb4QgEBBAQDAgAHMDgGCWCGSAGG+EIBDQQrFilTdGFydENvbSBGcmVlIFNT +TCBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTANBgkqhkiG9w0BAQUFAAOCAgEAFmyZ +9GYMNPXQhV59CuzaEE44HF7fpiUFS5Eyweg78T3dRAlbB0mKKctmArexmvclmAk8 +jhvh3TaHK0u7aNM5Zj2gJsfyOZEdUauCe37Vzlrk4gNXcGmXCPleWKYK34wGmkUW +FjgKXlf2Ysd6AgXmvB618p70qSmD+LIU424oh0TDkBreOKk8rENNZEXO3SipXPJz +ewT4F+irsfMuXGRuczE6Eri8sxHkfY+BUZo7jYn0TZNmezwD7dOaHZrzZVD1oNB1 +ny+v8OqCQ5j4aZyJecRDjkZy42Q2Eq/3JR44iZB3fsNrarnDy0RLrHiQi+fHLB5L +EUTINFInzQpdn4XBidUaePKVEFMy3YCEZnXZtWgo+2EuvoSoOMCZEoalHmdkrQYu +L6lwhceWD3yJZfWOQ1QOq92lgDmUYMA0yZZwLKMS9R9Ie70cfmu3nZD0Ijuu+Pwq +yvqCUqDvr0tVk+vBtfAii6w0TiYiBKGHLHVKt+V9E9e4DGTANtLJL4YSjCMJwRuC +O3NJo2pXh5Tl1njFmUNj403gdy3hZZlyaQQaRwnmDwFWJPsfvw55qVguucQJAX6V +um0ABj6y6koQOdjQK/W/7HW/lwLFCRsI3FU34oH7N4RDYiDK51ZLZer+bMEkkySh +NOsF/5oirpt9P/FlUQqmMGqz9IgcgA38corog14= +-----END CERTIFICATE-----` diff --git a/libgo/go/crypto/x509/x509.go b/libgo/go/crypto/x509/x509.go index 6825030d6f9..d0c5a26a9a8 100644 --- a/libgo/go/crypto/x509/x509.go +++ b/libgo/go/crypto/x509/x509.go @@ -2,12 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package parses X.509-encoded keys and certificates. +// Package x509 parses X.509-encoded keys and certificates. package x509 import ( "asn1" "big" + "bytes" "container/vector" "crypto" "crypto/rsa" @@ -15,7 +16,6 @@ import ( "hash" "io" "os" - "strings" "time" ) @@ -27,6 +27,20 @@ type pkcs1PrivateKey struct { D asn1.RawValue P asn1.RawValue Q asn1.RawValue + // We ignore these values, if present, because rsa will calculate them. + Dp asn1.RawValue "optional" + Dq asn1.RawValue "optional" + Qinv asn1.RawValue "optional" + + AdditionalPrimes []pkcs1AddtionalRSAPrime "optional" +} + +type pkcs1AddtionalRSAPrime struct { + Prime asn1.RawValue + + // We ignore these values because rsa will calculate them. + Exp asn1.RawValue + Coeff asn1.RawValue } // rawValueIsInteger returns true iff the given ASN.1 RawValue is an INTEGER type. @@ -46,6 +60,10 @@ func ParsePKCS1PrivateKey(der []byte) (key *rsa.PrivateKey, err os.Error) { return } + if priv.Version > 1 { + return nil, os.ErrorString("x509: unsupported private key version") + } + if !rawValueIsInteger(&priv.N) || !rawValueIsInteger(&priv.D) || !rawValueIsInteger(&priv.P) || @@ -61,26 +79,66 @@ func ParsePKCS1PrivateKey(der []byte) (key *rsa.PrivateKey, err os.Error) { } key.D = new(big.Int).SetBytes(priv.D.Bytes) - key.P = new(big.Int).SetBytes(priv.P.Bytes) - key.Q = new(big.Int).SetBytes(priv.Q.Bytes) + key.Primes = make([]*big.Int, 2+len(priv.AdditionalPrimes)) + key.Primes[0] = new(big.Int).SetBytes(priv.P.Bytes) + key.Primes[1] = new(big.Int).SetBytes(priv.Q.Bytes) + for i, a := range priv.AdditionalPrimes { + if !rawValueIsInteger(&a.Prime) { + return nil, asn1.StructuralError{"tags don't match"} + } + key.Primes[i+2] = new(big.Int).SetBytes(a.Prime.Bytes) + // We ignore the other two values because rsa will calculate + // them as needed. + } err = key.Validate() if err != nil { return nil, err } + key.Precompute() return } +// rawValueForBig returns an asn1.RawValue which represents the given integer. +func rawValueForBig(n *big.Int) asn1.RawValue { + b := n.Bytes() + if n.Sign() >= 0 && len(b) > 0 && b[0]&0x80 != 0 { + // This positive number would be interpreted as a negative + // number in ASN.1 because the MSB is set. + padded := make([]byte, len(b)+1) + copy(padded[1:], b) + b = padded + } + return asn1.RawValue{Tag: 2, Bytes: b} +} + // MarshalPKCS1PrivateKey converts a private key to ASN.1 DER encoded form. func MarshalPKCS1PrivateKey(key *rsa.PrivateKey) []byte { + key.Precompute() + + version := 0 + if len(key.Primes) > 2 { + version = 1 + } + priv := pkcs1PrivateKey{ - Version: 1, - N: asn1.RawValue{Tag: 2, Bytes: key.PublicKey.N.Bytes()}, + Version: version, + N: rawValueForBig(key.N), E: key.PublicKey.E, - D: asn1.RawValue{Tag: 2, Bytes: key.D.Bytes()}, - P: asn1.RawValue{Tag: 2, Bytes: key.P.Bytes()}, - Q: asn1.RawValue{Tag: 2, Bytes: key.Q.Bytes()}, + D: rawValueForBig(key.D), + P: rawValueForBig(key.Primes[0]), + Q: rawValueForBig(key.Primes[1]), + Dp: rawValueForBig(key.Precomputed.Dp), + Dq: rawValueForBig(key.Precomputed.Dq), + Qinv: rawValueForBig(key.Precomputed.Qinv), + } + + priv.AdditionalPrimes = make([]pkcs1AddtionalRSAPrime, len(key.Precomputed.CRTValues)) + for i, values := range key.Precomputed.CRTValues { + priv.AdditionalPrimes[i].Prime = rawValueForBig(key.Primes[2+i]) + priv.AdditionalPrimes[i].Exp = rawValueForBig(values.Exp) + priv.AdditionalPrimes[i].Coeff = rawValueForBig(values.Coeff) } b, _ := asn1.Marshal(priv) @@ -90,6 +148,7 @@ func MarshalPKCS1PrivateKey(key *rsa.PrivateKey) []byte { // These structures reflect the ASN.1 structure of X.509 certificates.: type certificate struct { + Raw asn1.RawContent TBSCertificate tbsCertificate SignatureAlgorithm algorithmIdentifier SignatureValue asn1.BitString @@ -127,6 +186,7 @@ type validity struct { } type publicKeyInfo struct { + Raw asn1.RawContent Algorithm algorithmIdentifier PublicKey asn1.BitString } @@ -343,7 +403,10 @@ const ( // A Certificate represents an X.509 certificate. type Certificate struct { - Raw []byte // Raw ASN.1 DER contents. + Raw []byte // Complete ASN.1 DER content (certificate, signature algorithm and signature). + RawTBSCertificate []byte // Certificate part of raw ASN.1 DER content. + RawSubjectPublicKeyInfo []byte // DER encoded SubjectPublicKeyInfo. + Signature []byte SignatureAlgorithm SignatureAlgorithm @@ -395,6 +458,10 @@ func (ConstraintViolationError) String() string { return "invalid signature: parent certificate cannot sign this kind of certificate" } +func (c *Certificate) Equal(other *Certificate) bool { + return bytes.Equal(c.Raw, other.Raw) +} + // CheckSignatureFrom verifies that the signature on c is a valid signature // from parent. func (c *Certificate) CheckSignatureFrom(parent *Certificate) (err os.Error) { @@ -434,69 +501,12 @@ func (c *Certificate) CheckSignatureFrom(parent *Certificate) (err os.Error) { return UnsupportedAlgorithmError{} } - h.Write(c.Raw) + h.Write(c.RawTBSCertificate) digest := h.Sum() return rsa.VerifyPKCS1v15(pub, hashType, digest, c.Signature) } -func matchHostnames(pattern, host string) bool { - if len(pattern) == 0 || len(host) == 0 { - return false - } - - patternParts := strings.Split(pattern, ".", -1) - hostParts := strings.Split(host, ".", -1) - - if len(patternParts) != len(hostParts) { - return false - } - - for i, patternPart := range patternParts { - if patternPart == "*" { - continue - } - if patternPart != hostParts[i] { - return false - } - } - - return true -} - -type HostnameError struct { - Certificate *Certificate - Host string -} - -func (h *HostnameError) String() string { - var valid string - c := h.Certificate - if len(c.DNSNames) > 0 { - valid = strings.Join(c.DNSNames, ", ") - } else { - valid = c.Subject.CommonName - } - return "certificate is valid for " + valid + ", not " + h.Host -} - -// VerifyHostname returns nil if c is a valid certificate for the named host. -// Otherwise it returns an os.Error describing the mismatch. -func (c *Certificate) VerifyHostname(h string) os.Error { - if len(c.DNSNames) > 0 { - for _, match := range c.DNSNames { - if matchHostnames(match, h) { - return nil - } - } - // If Subject Alt Name is given, we ignore the common name. - } else if matchHostnames(c.Subject.CommonName, h) { - return nil - } - - return &HostnameError{c, h} -} - type UnhandledCriticalExtension struct{} func (h UnhandledCriticalExtension) String() string { @@ -558,7 +568,9 @@ func parsePublicKey(algo PublicKeyAlgorithm, asn1Data []byte) (interface{}, os.E func parseCertificate(in *certificate) (*Certificate, os.Error) { out := new(Certificate) - out.Raw = in.TBSCertificate.Raw + out.Raw = in.Raw + out.RawTBSCertificate = in.TBSCertificate.Raw + out.RawSubjectPublicKeyInfo = in.TBSCertificate.PublicKey.Raw out.Signature = in.SignatureValue.RightAlign() out.SignatureAlgorithm = @@ -975,7 +987,7 @@ func CreateCertificate(rand io.Reader, template, parent *Certificate, pub *rsa.P Issuer: parent.Subject.toRDNSequence(), Validity: validity{template.NotBefore, template.NotAfter}, Subject: template.Subject.toRDNSequence(), - PublicKey: publicKeyInfo{algorithmIdentifier{oidRSA}, encodedPublicKey}, + PublicKey: publicKeyInfo{nil, algorithmIdentifier{oidRSA}, encodedPublicKey}, Extensions: extensions, } @@ -996,6 +1008,7 @@ func CreateCertificate(rand io.Reader, template, parent *Certificate, pub *rsa.P } cert, err = asn1.Marshal(certificate{ + nil, c, algorithmIdentifier{oidSHA1WithRSA}, asn1.BitString{Bytes: signature, BitLength: len(signature) * 8}, diff --git a/libgo/go/crypto/x509/x509_test.go b/libgo/go/crypto/x509/x509_test.go index d9511b863fb..a42113addda 100644 --- a/libgo/go/crypto/x509/x509_test.go +++ b/libgo/go/crypto/x509/x509_test.go @@ -20,12 +20,13 @@ func TestParsePKCS1PrivateKey(t *testing.T) { priv, err := ParsePKCS1PrivateKey(block.Bytes) if err != nil { t.Errorf("Failed to parse private key: %s", err) + return } if priv.PublicKey.N.Cmp(rsaPrivateKey.PublicKey.N) != 0 || priv.PublicKey.E != rsaPrivateKey.PublicKey.E || priv.D.Cmp(rsaPrivateKey.D) != 0 || - priv.P.Cmp(rsaPrivateKey.P) != 0 || - priv.Q.Cmp(rsaPrivateKey.Q) != 0 { + priv.Primes[0].Cmp(rsaPrivateKey.Primes[0]) != 0 || + priv.Primes[1].Cmp(rsaPrivateKey.Primes[1]) != 0 { t.Errorf("got:%+v want:%+v", priv, rsaPrivateKey) } } @@ -47,14 +48,54 @@ func bigFromString(s string) *big.Int { return ret } +func fromBase10(base10 string) *big.Int { + i := new(big.Int) + i.SetString(base10, 10) + return i +} + var rsaPrivateKey = &rsa.PrivateKey{ PublicKey: rsa.PublicKey{ N: bigFromString("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"), E: 65537, }, D: bigFromString("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"), - P: bigFromString("98920366548084643601728869055592650835572950932266967461790948584315647051443"), - Q: bigFromString("94560208308847015747498523884063394671606671904944666360068158221458669711639"), + Primes: []*big.Int{ + bigFromString("98920366548084643601728869055592650835572950932266967461790948584315647051443"), + bigFromString("94560208308847015747498523884063394671606671904944666360068158221458669711639"), + }, +} + +func TestMarshalRSAPrivateKey(t *testing.T) { + priv := &rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + N: fromBase10("16346378922382193400538269749936049106320265317511766357599732575277382844051791096569333808598921852351577762718529818072849191122419410612033592401403764925096136759934497687765453905884149505175426053037420486697072448609022753683683718057795566811401938833367954642951433473337066311978821180526439641496973296037000052546108507805269279414789035461158073156772151892452251106173507240488993608650881929629163465099476849643165682709047462010581308719577053905787496296934240246311806555924593059995202856826239801816771116902778517096212527979497399966526283516447337775509777558018145573127308919204297111496233"), + E: 3, + }, + D: fromBase10("10897585948254795600358846499957366070880176878341177571733155050184921896034527397712889205732614568234385175145686545381899460748279607074689061600935843283397424506622998458510302603922766336783617368686090042765718290914099334449154829375179958369993407724946186243249568928237086215759259909861748642124071874879861299389874230489928271621259294894142840428407196932444474088857746123104978617098858619445675532587787023228852383149557470077802718705420275739737958953794088728369933811184572620857678792001136676902250566845618813972833750098806496641114644760255910789397593428910198080271317419213080834885003"), + Primes: []*big.Int{ + fromBase10("1025363189502892836833747188838978207017355117492483312747347695538428729137306368764177201532277413433182799108299960196606011786562992097313508180436744488171474690412562218914213688661311117337381958560443"), + fromBase10("3467903426626310123395340254094941045497208049900750380025518552334536945536837294961497712862519984786362199788654739924501424784631315081391467293694361474867825728031147665777546570788493758372218019373"), + fromBase10("4597024781409332673052708605078359346966325141767460991205742124888960305710298765592730135879076084498363772408626791576005136245060321874472727132746643162385746062759369754202494417496879741537284589047"), + }, + } + + derBytes := MarshalPKCS1PrivateKey(priv) + + priv2, err := ParsePKCS1PrivateKey(derBytes) + if err != nil { + t.Errorf("error parsing serialized key: %s", err) + return + } + if priv.PublicKey.N.Cmp(priv2.PublicKey.N) != 0 || + priv.PublicKey.E != priv2.PublicKey.E || + priv.D.Cmp(priv2.D) != 0 || + len(priv2.Primes) != 3 || + priv.Primes[0].Cmp(priv2.Primes[0]) != 0 || + priv.Primes[1].Cmp(priv2.Primes[1]) != 0 || + priv.Primes[2].Cmp(priv2.Primes[2]) != 0 { + t.Errorf("got:%+v want:%+v", priv, priv2) + } } type matchHostnamesTest struct { diff --git a/libgo/go/crypto/xtea/cipher.go b/libgo/go/crypto/xtea/cipher.go index b0fa2a1844d..f2a5da0035c 100644 --- a/libgo/go/crypto/xtea/cipher.go +++ b/libgo/go/crypto/xtea/cipher.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements XTEA encryption, as defined in Needham and -// Wheeler's 1997 technical report, "Tea extensions." +// Package xtea implements XTEA encryption, as defined in Needham and Wheeler's +// 1997 technical report, "Tea extensions." package xtea // For details, see http://www.cix.co.uk/~klockstone/xtea.pdf diff --git a/libgo/go/debug/dwarf/open.go b/libgo/go/debug/dwarf/open.go index cb009e0e09e..d9525f78835 100644 --- a/libgo/go/debug/dwarf/open.go +++ b/libgo/go/debug/dwarf/open.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package provides access to DWARF debugging information -// loaded from executable files, as defined in the DWARF 2.0 Standard -// at http://dwarfstd.org/doc/dwarf-2.0.0.pdf +// Package dwarf provides access to DWARF debugging information loaded from +// executable files, as defined in the DWARF 2.0 Standard at +// http://dwarfstd.org/doc/dwarf-2.0.0.pdf package dwarf import ( diff --git a/libgo/go/debug/elf/elf.go b/libgo/go/debug/elf/elf.go index 74e97998639..5d45b24863d 100644 --- a/libgo/go/debug/elf/elf.go +++ b/libgo/go/debug/elf/elf.go @@ -330,29 +330,35 @@ func (i SectionIndex) GoString() string { return stringName(uint32(i), shnString type SectionType uint32 const ( - SHT_NULL SectionType = 0 /* inactive */ - SHT_PROGBITS SectionType = 1 /* program defined information */ - SHT_SYMTAB SectionType = 2 /* symbol table section */ - SHT_STRTAB SectionType = 3 /* string table section */ - SHT_RELA SectionType = 4 /* relocation section with addends */ - SHT_HASH SectionType = 5 /* symbol hash table section */ - SHT_DYNAMIC SectionType = 6 /* dynamic section */ - SHT_NOTE SectionType = 7 /* note section */ - SHT_NOBITS SectionType = 8 /* no space section */ - SHT_REL SectionType = 9 /* relocation section - no addends */ - SHT_SHLIB SectionType = 10 /* reserved - purpose unknown */ - SHT_DYNSYM SectionType = 11 /* dynamic symbol table section */ - SHT_INIT_ARRAY SectionType = 14 /* Initialization function pointers. */ - SHT_FINI_ARRAY SectionType = 15 /* Termination function pointers. */ - SHT_PREINIT_ARRAY SectionType = 16 /* Pre-initialization function ptrs. */ - SHT_GROUP SectionType = 17 /* Section group. */ - SHT_SYMTAB_SHNDX SectionType = 18 /* Section indexes (see SHN_XINDEX). */ - SHT_LOOS SectionType = 0x60000000 /* First of OS specific semantics */ - SHT_HIOS SectionType = 0x6fffffff /* Last of OS specific semantics */ - SHT_LOPROC SectionType = 0x70000000 /* reserved range for processor */ - SHT_HIPROC SectionType = 0x7fffffff /* specific section header types */ - SHT_LOUSER SectionType = 0x80000000 /* reserved range for application */ - SHT_HIUSER SectionType = 0xffffffff /* specific indexes */ + SHT_NULL SectionType = 0 /* inactive */ + SHT_PROGBITS SectionType = 1 /* program defined information */ + SHT_SYMTAB SectionType = 2 /* symbol table section */ + SHT_STRTAB SectionType = 3 /* string table section */ + SHT_RELA SectionType = 4 /* relocation section with addends */ + SHT_HASH SectionType = 5 /* symbol hash table section */ + SHT_DYNAMIC SectionType = 6 /* dynamic section */ + SHT_NOTE SectionType = 7 /* note section */ + SHT_NOBITS SectionType = 8 /* no space section */ + SHT_REL SectionType = 9 /* relocation section - no addends */ + SHT_SHLIB SectionType = 10 /* reserved - purpose unknown */ + SHT_DYNSYM SectionType = 11 /* dynamic symbol table section */ + SHT_INIT_ARRAY SectionType = 14 /* Initialization function pointers. */ + SHT_FINI_ARRAY SectionType = 15 /* Termination function pointers. */ + SHT_PREINIT_ARRAY SectionType = 16 /* Pre-initialization function ptrs. */ + SHT_GROUP SectionType = 17 /* Section group. */ + SHT_SYMTAB_SHNDX SectionType = 18 /* Section indexes (see SHN_XINDEX). */ + SHT_LOOS SectionType = 0x60000000 /* First of OS specific semantics */ + SHT_GNU_ATTRIBUTES SectionType = 0x6ffffff5 /* GNU object attributes */ + SHT_GNU_HASH SectionType = 0x6ffffff6 /* GNU hash table */ + SHT_GNU_LIBLIST SectionType = 0x6ffffff7 /* GNU prelink library list */ + SHT_GNU_VERDEF SectionType = 0x6ffffffd /* GNU version definition section */ + SHT_GNU_VERNEED SectionType = 0x6ffffffe /* GNU version needs section */ + SHT_GNU_VERSYM SectionType = 0x6fffffff /* GNU version symbol table */ + SHT_HIOS SectionType = 0x6fffffff /* Last of OS specific semantics */ + SHT_LOPROC SectionType = 0x70000000 /* reserved range for processor */ + SHT_HIPROC SectionType = 0x7fffffff /* specific section header types */ + SHT_LOUSER SectionType = 0x80000000 /* reserved range for application */ + SHT_HIUSER SectionType = 0xffffffff /* specific indexes */ ) var shtStrings = []intName{ @@ -374,7 +380,12 @@ var shtStrings = []intName{ {17, "SHT_GROUP"}, {18, "SHT_SYMTAB_SHNDX"}, {0x60000000, "SHT_LOOS"}, - {0x6fffffff, "SHT_HIOS"}, + {0x6ffffff5, "SHT_GNU_ATTRIBUTES"}, + {0x6ffffff6, "SHT_GNU_HASH"}, + {0x6ffffff7, "SHT_GNU_LIBLIST"}, + {0x6ffffffd, "SHT_GNU_VERDEF"}, + {0x6ffffffe, "SHT_GNU_VERNEED"}, + {0x6fffffff, "SHT_GNU_VERSYM"}, {0x70000000, "SHT_LOPROC"}, {0x7fffffff, "SHT_HIPROC"}, {0x80000000, "SHT_LOUSER"}, @@ -518,6 +529,9 @@ const ( DT_PREINIT_ARRAYSZ DynTag = 33 /* Size in bytes of the array of pre-initialization functions. */ DT_LOOS DynTag = 0x6000000d /* First OS-specific */ DT_HIOS DynTag = 0x6ffff000 /* Last OS-specific */ + DT_VERSYM DynTag = 0x6ffffff0 + DT_VERNEED DynTag = 0x6ffffffe + DT_VERNEEDNUM DynTag = 0x6fffffff DT_LOPROC DynTag = 0x70000000 /* First processor-specific type. */ DT_HIPROC DynTag = 0x7fffffff /* Last processor-specific type. */ ) @@ -559,6 +573,9 @@ var dtStrings = []intName{ {33, "DT_PREINIT_ARRAYSZ"}, {0x6000000d, "DT_LOOS"}, {0x6ffff000, "DT_HIOS"}, + {0x6ffffff0, "DT_VERSYM"}, + {0x6ffffffe, "DT_VERNEED"}, + {0x6fffffff, "DT_VERNEEDNUM"}, {0x70000000, "DT_LOPROC"}, {0x7fffffff, "DT_HIPROC"}, } diff --git a/libgo/go/debug/elf/file.go b/libgo/go/debug/elf/file.go index 6fdcda6d485..9ae8b413d91 100644 --- a/libgo/go/debug/elf/file.go +++ b/libgo/go/debug/elf/file.go @@ -35,9 +35,11 @@ type FileHeader struct { // A File represents an open ELF file. type File struct { FileHeader - Sections []*Section - Progs []*Prog - closer io.Closer + Sections []*Section + Progs []*Prog + closer io.Closer + gnuNeed []verneed + gnuVersym []byte } // A SectionHeader represents a single ELF section header. @@ -329,8 +331,8 @@ func NewFile(r io.ReaderAt) (*File, os.Error) { } // getSymbols returns a slice of Symbols from parsing the symbol table -// with the given type. -func (f *File) getSymbols(typ SectionType) ([]Symbol, os.Error) { +// with the given type, along with the associated string table. +func (f *File) getSymbols(typ SectionType) ([]Symbol, []byte, os.Error) { switch f.Class { case ELFCLASS64: return f.getSymbols64(typ) @@ -339,27 +341,27 @@ func (f *File) getSymbols(typ SectionType) ([]Symbol, os.Error) { return f.getSymbols32(typ) } - return nil, os.ErrorString("not implemented") + return nil, nil, os.ErrorString("not implemented") } -func (f *File) getSymbols32(typ SectionType) ([]Symbol, os.Error) { +func (f *File) getSymbols32(typ SectionType) ([]Symbol, []byte, os.Error) { symtabSection := f.SectionByType(typ) if symtabSection == nil { - return nil, os.ErrorString("no symbol section") + return nil, nil, os.ErrorString("no symbol section") } data, err := symtabSection.Data() if err != nil { - return nil, os.ErrorString("cannot load symbol section") + return nil, nil, os.ErrorString("cannot load symbol section") } symtab := bytes.NewBuffer(data) if symtab.Len()%Sym32Size != 0 { - return nil, os.ErrorString("length of symbol section is not a multiple of SymSize") + return nil, nil, os.ErrorString("length of symbol section is not a multiple of SymSize") } strdata, err := f.stringTable(symtabSection.Link) if err != nil { - return nil, os.ErrorString("cannot load string table section") + return nil, nil, os.ErrorString("cannot load string table section") } // The first entry is all zeros. @@ -382,27 +384,27 @@ func (f *File) getSymbols32(typ SectionType) ([]Symbol, os.Error) { i++ } - return symbols, nil + return symbols, strdata, nil } -func (f *File) getSymbols64(typ SectionType) ([]Symbol, os.Error) { +func (f *File) getSymbols64(typ SectionType) ([]Symbol, []byte, os.Error) { symtabSection := f.SectionByType(typ) if symtabSection == nil { - return nil, os.ErrorString("no symbol section") + return nil, nil, os.ErrorString("no symbol section") } data, err := symtabSection.Data() if err != nil { - return nil, os.ErrorString("cannot load symbol section") + return nil, nil, os.ErrorString("cannot load symbol section") } symtab := bytes.NewBuffer(data) if symtab.Len()%Sym64Size != 0 { - return nil, os.ErrorString("length of symbol section is not a multiple of Sym64Size") + return nil, nil, os.ErrorString("length of symbol section is not a multiple of Sym64Size") } strdata, err := f.stringTable(symtabSection.Link) if err != nil { - return nil, os.ErrorString("cannot load string table section") + return nil, nil, os.ErrorString("cannot load string table section") } // The first entry is all zeros. @@ -425,7 +427,7 @@ func (f *File) getSymbols64(typ SectionType) ([]Symbol, os.Error) { i++ } - return symbols, nil + return symbols, strdata, nil } // getString extracts a string from an ELF string table. @@ -468,7 +470,7 @@ func (f *File) applyRelocationsAMD64(dst []byte, rels []byte) os.Error { return os.ErrorString("length of relocation section is not a multiple of Sym64Size") } - symbols, err := f.getSymbols(SHT_SYMTAB) + symbols, _, err := f.getSymbols(SHT_SYMTAB) if err != nil { return err } @@ -544,24 +546,123 @@ func (f *File) DWARF() (*dwarf.Data, os.Error) { return dwarf.New(abbrev, nil, nil, info, nil, nil, nil, str) } +type ImportedSymbol struct { + Name string + Version string + Library string +} + // ImportedSymbols returns the names of all symbols // referred to by the binary f that are expected to be // satisfied by other libraries at dynamic load time. // It does not return weak symbols. -func (f *File) ImportedSymbols() ([]string, os.Error) { - sym, err := f.getSymbols(SHT_DYNSYM) +func (f *File) ImportedSymbols() ([]ImportedSymbol, os.Error) { + sym, str, err := f.getSymbols(SHT_DYNSYM) if err != nil { return nil, err } - var all []string - for _, s := range sym { + f.gnuVersionInit(str) + var all []ImportedSymbol + for i, s := range sym { if ST_BIND(s.Info) == STB_GLOBAL && s.Section == SHN_UNDEF { - all = append(all, s.Name) + all = append(all, ImportedSymbol{Name: s.Name}) + f.gnuVersion(i, &all[len(all)-1]) } } return all, nil } +type verneed struct { + File string + Name string +} + +// gnuVersionInit parses the GNU version tables +// for use by calls to gnuVersion. +func (f *File) gnuVersionInit(str []byte) { + // Accumulate verneed information. + vn := f.SectionByType(SHT_GNU_VERNEED) + if vn == nil { + return + } + d, _ := vn.Data() + + var need []verneed + i := 0 + for { + if i+16 > len(d) { + break + } + vers := f.ByteOrder.Uint16(d[i : i+2]) + if vers != 1 { + break + } + cnt := f.ByteOrder.Uint16(d[i+2 : i+4]) + fileoff := f.ByteOrder.Uint32(d[i+4 : i+8]) + aux := f.ByteOrder.Uint32(d[i+8 : i+12]) + next := f.ByteOrder.Uint32(d[i+12 : i+16]) + file, _ := getString(str, int(fileoff)) + + var name string + j := i + int(aux) + for c := 0; c < int(cnt); c++ { + if j+16 > len(d) { + break + } + // hash := f.ByteOrder.Uint32(d[j:j+4]) + // flags := f.ByteOrder.Uint16(d[j+4:j+6]) + other := f.ByteOrder.Uint16(d[j+6 : j+8]) + nameoff := f.ByteOrder.Uint32(d[j+8 : j+12]) + next := f.ByteOrder.Uint32(d[j+12 : j+16]) + name, _ = getString(str, int(nameoff)) + ndx := int(other) + if ndx >= len(need) { + a := make([]verneed, 2*(ndx+1)) + copy(a, need) + need = a + } + + need[ndx] = verneed{file, name} + if next == 0 { + break + } + j += int(next) + } + + if next == 0 { + break + } + i += int(next) + } + + // Versym parallels symbol table, indexing into verneed. + vs := f.SectionByType(SHT_GNU_VERSYM) + if vs == nil { + return + } + d, _ = vs.Data() + + f.gnuNeed = need + f.gnuVersym = d +} + +// gnuVersion adds Library and Version information to sym, +// which came from offset i of the symbol table. +func (f *File) gnuVersion(i int, sym *ImportedSymbol) { + // Each entry is two bytes; skip undef entry at beginning. + i = (i + 1) * 2 + if i >= len(f.gnuVersym) { + return + } + j := int(f.ByteOrder.Uint16(f.gnuVersym[i:])) + if j < 2 || j >= len(f.gnuNeed) { + return + } + n := &f.gnuNeed[j] + sym.Library = n.File + sym.Version = n.Name +} + // ImportedLibraries returns the names of all libraries // referred to by the binary f that are expected to be // linked with the binary at dynamic link time. diff --git a/libgo/go/ebnf/ebnf.go b/libgo/go/ebnf/ebnf.go index e5aabd582b3..7918c4593bb 100644 --- a/libgo/go/ebnf/ebnf.go +++ b/libgo/go/ebnf/ebnf.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// A library for EBNF grammars. The input is text ([]byte) satisfying -// the following grammar (represented itself in EBNF): +// Package ebnf is a library for EBNF grammars. The input is text ([]byte) +// satisfying the following grammar (represented itself in EBNF): // // Production = name "=" Expression "." . // Expression = Alternative { "|" Alternative } . diff --git a/libgo/go/encoding/binary/binary.go b/libgo/go/encoding/binary/binary.go index ee2f23dbba2..a01d0e02464 100644 --- a/libgo/go/encoding/binary/binary.go +++ b/libgo/go/encoding/binary/binary.go @@ -126,10 +126,10 @@ func (bigEndian) GoString() string { return "binary.BigEndian" } // and written to successive fields of the data. func Read(r io.Reader, order ByteOrder, data interface{}) os.Error { var v reflect.Value - switch d := reflect.NewValue(data).(type) { - case *reflect.PtrValue: + switch d := reflect.ValueOf(data); d.Kind() { + case reflect.Ptr: v = d.Elem() - case *reflect.SliceValue: + case reflect.Slice: v = d default: return os.NewError("binary.Read: invalid type " + d.Type().String()) @@ -155,7 +155,7 @@ func Read(r io.Reader, order ByteOrder, data interface{}) os.Error { // Bytes written to w are encoded using the specified byte order // and read from successive fields of the data. func Write(w io.Writer, order ByteOrder, data interface{}) os.Error { - v := reflect.Indirect(reflect.NewValue(data)) + v := reflect.Indirect(reflect.ValueOf(data)) size := TotalSize(v) if size < 0 { return os.NewError("binary.Write: invalid type " + v.Type().String()) @@ -168,26 +168,26 @@ func Write(w io.Writer, order ByteOrder, data interface{}) os.Error { } func TotalSize(v reflect.Value) int { - if sv, ok := v.(*reflect.SliceValue); ok { - elem := sizeof(v.Type().(*reflect.SliceType).Elem()) + if v.Kind() == reflect.Slice { + elem := sizeof(v.Type().Elem()) if elem < 0 { return -1 } - return sv.Len() * elem + return v.Len() * elem } return sizeof(v.Type()) } -func sizeof(v reflect.Type) int { - switch t := v.(type) { - case *reflect.ArrayType: +func sizeof(t reflect.Type) int { + switch t.Kind() { + case reflect.Array: n := sizeof(t.Elem()) if n < 0 { return -1 } return t.Len() * n - case *reflect.StructType: + case reflect.Struct: sum := 0 for i, n := 0, t.NumField(); i < n; i++ { s := sizeof(t.Field(i).Type) @@ -198,12 +198,10 @@ func sizeof(v reflect.Type) int { } return sum - case *reflect.UintType, *reflect.IntType, *reflect.FloatType, *reflect.ComplexType: - switch t := t.Kind(); t { - case reflect.Int, reflect.Uint, reflect.Uintptr: - return -1 - } - return int(v.Size()) + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + return int(t.Size()) } return -1 } @@ -279,130 +277,118 @@ func (d *decoder) int64() int64 { return int64(d.uint64()) } func (e *encoder) int64(x int64) { e.uint64(uint64(x)) } func (d *decoder) value(v reflect.Value) { - switch v := v.(type) { - case *reflect.ArrayValue: + switch v.Kind() { + case reflect.Array: l := v.Len() for i := 0; i < l; i++ { - d.value(v.Elem(i)) + d.value(v.Index(i)) } - case *reflect.StructValue: + case reflect.Struct: l := v.NumField() for i := 0; i < l; i++ { d.value(v.Field(i)) } - case *reflect.SliceValue: + case reflect.Slice: l := v.Len() for i := 0; i < l; i++ { - d.value(v.Elem(i)) + d.value(v.Index(i)) } - case *reflect.IntValue: - switch v.Type().Kind() { - case reflect.Int8: - v.Set(int64(d.int8())) - case reflect.Int16: - v.Set(int64(d.int16())) - case reflect.Int32: - v.Set(int64(d.int32())) - case reflect.Int64: - v.Set(d.int64()) - } - - case *reflect.UintValue: - switch v.Type().Kind() { - case reflect.Uint8: - v.Set(uint64(d.uint8())) - case reflect.Uint16: - v.Set(uint64(d.uint16())) - case reflect.Uint32: - v.Set(uint64(d.uint32())) - case reflect.Uint64: - v.Set(d.uint64()) - } - - case *reflect.FloatValue: - switch v.Type().Kind() { - case reflect.Float32: - v.Set(float64(math.Float32frombits(d.uint32()))) - case reflect.Float64: - v.Set(math.Float64frombits(d.uint64())) - } - - case *reflect.ComplexValue: - switch v.Type().Kind() { - case reflect.Complex64: - v.Set(complex( - float64(math.Float32frombits(d.uint32())), - float64(math.Float32frombits(d.uint32())), - )) - case reflect.Complex128: - v.Set(complex( - math.Float64frombits(d.uint64()), - math.Float64frombits(d.uint64()), - )) - } + case reflect.Int8: + v.SetInt(int64(d.int8())) + case reflect.Int16: + v.SetInt(int64(d.int16())) + case reflect.Int32: + v.SetInt(int64(d.int32())) + case reflect.Int64: + v.SetInt(d.int64()) + + case reflect.Uint8: + v.SetUint(uint64(d.uint8())) + case reflect.Uint16: + v.SetUint(uint64(d.uint16())) + case reflect.Uint32: + v.SetUint(uint64(d.uint32())) + case reflect.Uint64: + v.SetUint(d.uint64()) + + case reflect.Float32: + v.SetFloat(float64(math.Float32frombits(d.uint32()))) + case reflect.Float64: + v.SetFloat(math.Float64frombits(d.uint64())) + + case reflect.Complex64: + v.SetComplex(complex( + float64(math.Float32frombits(d.uint32())), + float64(math.Float32frombits(d.uint32())), + )) + case reflect.Complex128: + v.SetComplex(complex( + math.Float64frombits(d.uint64()), + math.Float64frombits(d.uint64()), + )) } } func (e *encoder) value(v reflect.Value) { - switch v := v.(type) { - case *reflect.ArrayValue: + switch v.Kind() { + case reflect.Array: l := v.Len() for i := 0; i < l; i++ { - e.value(v.Elem(i)) + e.value(v.Index(i)) } - case *reflect.StructValue: + case reflect.Struct: l := v.NumField() for i := 0; i < l; i++ { e.value(v.Field(i)) } - case *reflect.SliceValue: + case reflect.Slice: l := v.Len() for i := 0; i < l; i++ { - e.value(v.Elem(i)) + e.value(v.Index(i)) } - case *reflect.IntValue: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: switch v.Type().Kind() { case reflect.Int8: - e.int8(int8(v.Get())) + e.int8(int8(v.Int())) case reflect.Int16: - e.int16(int16(v.Get())) + e.int16(int16(v.Int())) case reflect.Int32: - e.int32(int32(v.Get())) + e.int32(int32(v.Int())) case reflect.Int64: - e.int64(v.Get()) + e.int64(v.Int()) } - case *reflect.UintValue: + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: switch v.Type().Kind() { case reflect.Uint8: - e.uint8(uint8(v.Get())) + e.uint8(uint8(v.Uint())) case reflect.Uint16: - e.uint16(uint16(v.Get())) + e.uint16(uint16(v.Uint())) case reflect.Uint32: - e.uint32(uint32(v.Get())) + e.uint32(uint32(v.Uint())) case reflect.Uint64: - e.uint64(v.Get()) + e.uint64(v.Uint()) } - case *reflect.FloatValue: + case reflect.Float32, reflect.Float64: switch v.Type().Kind() { case reflect.Float32: - e.uint32(math.Float32bits(float32(v.Get()))) + e.uint32(math.Float32bits(float32(v.Float()))) case reflect.Float64: - e.uint64(math.Float64bits(v.Get())) + e.uint64(math.Float64bits(v.Float())) } - case *reflect.ComplexValue: + case reflect.Complex64, reflect.Complex128: switch v.Type().Kind() { case reflect.Complex64: - x := v.Get() + x := v.Complex() e.uint32(math.Float32bits(float32(real(x)))) e.uint32(math.Float32bits(float32(imag(x)))) case reflect.Complex128: - x := v.Get() + x := v.Complex() e.uint64(math.Float64bits(real(x))) e.uint64(math.Float64bits(imag(x))) } diff --git a/libgo/go/encoding/binary/binary_test.go b/libgo/go/encoding/binary/binary_test.go index e09ec489fd4..7857c68d36e 100644 --- a/libgo/go/encoding/binary/binary_test.go +++ b/libgo/go/encoding/binary/binary_test.go @@ -152,7 +152,7 @@ func TestWriteT(t *testing.T) { t.Errorf("WriteT: have nil, want non-nil") } - tv := reflect.Indirect(reflect.NewValue(ts)).(*reflect.StructValue) + tv := reflect.Indirect(reflect.ValueOf(ts)) for i, n := 0, tv.NumField(); i < n; i++ { err = Write(buf, BigEndian, tv.Field(i).Interface()) if err == nil { diff --git a/libgo/go/encoding/hex/hex.go b/libgo/go/encoding/hex/hex.go index 292d917eb4c..891de186107 100644 --- a/libgo/go/encoding/hex/hex.go +++ b/libgo/go/encoding/hex/hex.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements hexadecimal encoding and decoding. +// Package hex implements hexadecimal encoding and decoding. package hex import ( diff --git a/libgo/go/encoding/line/line.go b/libgo/go/encoding/line/line.go index f46ce1c83a0..123962b1f91 100644 --- a/libgo/go/encoding/line/line.go +++ b/libgo/go/encoding/line/line.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The line package implements a Reader that reads lines delimited by '\n' or ' \r\n'. +// Package line implements a Reader that reads lines delimited by '\n' or +// ' \r\n'. package line import ( diff --git a/libgo/go/encoding/pem/pem.go b/libgo/go/encoding/pem/pem.go index 5653aeb77c7..44e3d0ad094 100644 --- a/libgo/go/encoding/pem/pem.go +++ b/libgo/go/encoding/pem/pem.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the PEM data encoding, which originated in Privacy +// Package pem implements the PEM data encoding, which originated in Privacy // Enhanced Mail. The most common use of PEM encoding today is in TLS keys and // certificates. See RFC 1421. package pem diff --git a/libgo/go/exec/exec.go b/libgo/go/exec/exec.go index 5398eb8e0ca..043f847283e 100644 --- a/libgo/go/exec/exec.go +++ b/libgo/go/exec/exec.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The exec package runs external commands. It wraps os.StartProcess -// to make it easier to remap stdin and stdout, connect I/O with pipes, -// and do other adjustments. +// Package exec runs external commands. It wraps os.StartProcess to make it +// easier to remap stdin and stdout, connect I/O with pipes, and do other +// adjustments. package exec // BUG(r): This package should be made even easier to use or merged into os. diff --git a/libgo/go/exec/exec_test.go b/libgo/go/exec/exec_test.go index 5e37b99eeca..eb8cd5fec9f 100644 --- a/libgo/go/exec/exec_test.go +++ b/libgo/go/exec/exec_test.go @@ -9,19 +9,14 @@ import ( "io/ioutil" "testing" "os" - "runtime" ) func run(argv []string, stdin, stdout, stderr int) (p *Cmd, err os.Error) { - if runtime.GOOS == "windows" { - argv = append([]string{"cmd", "/c"}, argv...) - } exe, err := LookPath(argv[0]) if err != nil { return nil, err } - p, err = Run(exe, argv, nil, "", stdin, stdout, stderr) - return p, err + return Run(exe, argv, nil, "", stdin, stdout, stderr) } func TestRunCat(t *testing.T) { diff --git a/libgo/go/exp/datafmt/datafmt.go b/libgo/go/exp/datafmt/datafmt.go index 46c412342ad..a8efdc58fe9 100644 --- a/libgo/go/exp/datafmt/datafmt.go +++ b/libgo/go/exp/datafmt/datafmt.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -/* The datafmt package implements syntax-directed, type-driven formatting +/* Package datafmt implements syntax-directed, type-driven formatting of arbitrary data structures. Formatting a data structure consists of two phases: first, a parser reads a format specification and builds a "compiled" format. Then, the format can be applied repeatedly to @@ -408,20 +408,20 @@ func (s *State) error(msg string) { // func typename(typ reflect.Type) string { - switch typ.(type) { - case *reflect.ArrayType: + switch typ.Kind() { + case reflect.Array: return "array" - case *reflect.SliceType: + case reflect.Slice: return "array" - case *reflect.ChanType: + case reflect.Chan: return "chan" - case *reflect.FuncType: + case reflect.Func: return "func" - case *reflect.InterfaceType: + case reflect.Interface: return "interface" - case *reflect.MapType: + case reflect.Map: return "map" - case *reflect.PtrType: + case reflect.Ptr: return "ptr" } return typ.String() @@ -519,38 +519,38 @@ func (s *State) eval(fexpr expr, value reflect.Value, index int) bool { case "*": // indirection: operation is type-specific - switch v := value.(type) { - case *reflect.ArrayValue: + switch v := value; v.Kind() { + case reflect.Array: if v.Len() <= index { return false } - value = v.Elem(index) + value = v.Index(index) - case *reflect.SliceValue: + case reflect.Slice: if v.IsNil() || v.Len() <= index { return false } - value = v.Elem(index) + value = v.Index(index) - case *reflect.MapValue: + case reflect.Map: s.error("reflection support for maps incomplete") - case *reflect.PtrValue: + case reflect.Ptr: if v.IsNil() { return false } value = v.Elem() - case *reflect.InterfaceValue: + case reflect.Interface: if v.IsNil() { return false } value = v.Elem() - case *reflect.ChanValue: + case reflect.Chan: s.error("reflection support for chans incomplete") - case *reflect.FuncValue: + case reflect.Func: s.error("reflection support for funcs incomplete") default: @@ -560,9 +560,9 @@ func (s *State) eval(fexpr expr, value reflect.Value, index int) bool { default: // value is value of named field var field reflect.Value - if sval, ok := value.(*reflect.StructValue); ok { + if sval := value; sval.Kind() == reflect.Struct { field = sval.FieldByName(t.fieldName) - if field == nil { + if !field.IsValid() { // TODO consider just returning false in this case s.error(fmt.Sprintf("error: no field `%s` in `%s`", t.fieldName, value.Type())) } @@ -671,8 +671,8 @@ func (f Format) Eval(env Environment, args ...interface{}) ([]byte, os.Error) { go func() { for _, v := range args { - fld := reflect.NewValue(v) - if fld == nil { + fld := reflect.ValueOf(v) + if !fld.IsValid() { errors <- os.NewError("nil argument") return } diff --git a/libgo/go/exp/draw/draw.go b/libgo/go/exp/draw/draw.go index 1d0729d922c..f98e2461894 100644 --- a/libgo/go/exp/draw/draw.go +++ b/libgo/go/exp/draw/draw.go @@ -8,7 +8,10 @@ // and the X Render extension. package draw -import "image" +import ( + "image" + "image/ycbcr" +) // m is the maximum color value returned by image.Color.RGBA. const m = 1<<16 - 1 @@ -65,29 +68,42 @@ func DrawMask(dst Image, r image.Rectangle, src image.Image, sp image.Point, mas if dst0, ok := dst.(*image.RGBA); ok { if op == Over { if mask == nil { - if src0, ok := src.(*image.ColorImage); ok { + switch src0 := src.(type) { + case *image.ColorImage: drawFillOver(dst0, r, src0) return - } - if src0, ok := src.(*image.RGBA); ok { + case *image.RGBA: drawCopyOver(dst0, r, src0, sp) return + case *image.NRGBA: + drawNRGBAOver(dst0, r, src0, sp) + return + case *ycbcr.YCbCr: + drawYCbCr(dst0, r, src0, sp) + return } } else if mask0, ok := mask.(*image.Alpha); ok { - if src0, ok := src.(*image.ColorImage); ok { + switch src0 := src.(type) { + case *image.ColorImage: drawGlyphOver(dst0, r, src0, mask0, mp) return } } } else { if mask == nil { - if src0, ok := src.(*image.ColorImage); ok { + switch src0 := src.(type) { + case *image.ColorImage: drawFillSrc(dst0, r, src0) return - } - if src0, ok := src.(*image.RGBA); ok { + case *image.RGBA: drawCopySrc(dst0, r, src0, sp) return + case *image.NRGBA: + drawNRGBASrc(dst0, r, src0, sp) + return + case *ycbcr.YCbCr: + drawYCbCr(dst0, r, src0, sp) + return } } } @@ -224,6 +240,36 @@ func drawCopyOver(dst *image.RGBA, r image.Rectangle, src *image.RGBA, sp image. } } +func drawNRGBAOver(dst *image.RGBA, r image.Rectangle, src *image.NRGBA, sp image.Point) { + for y, sy := r.Min.Y, sp.Y; y != r.Max.Y; y, sy = y+1, sy+1 { + dpix := dst.Pix[y*dst.Stride : (y+1)*dst.Stride] + spix := src.Pix[sy*src.Stride : (sy+1)*src.Stride] + for x, sx := r.Min.X, sp.X; x != r.Max.X; x, sx = x+1, sx+1 { + // Convert from non-premultiplied color to pre-multiplied color. + // The order of operations here is to match the NRGBAColor.RGBA + // method in image/color.go. + snrgba := spix[sx] + sa := uint32(snrgba.A) + sr := uint32(snrgba.R) * 0x101 * sa / 0xff + sg := uint32(snrgba.G) * 0x101 * sa / 0xff + sb := uint32(snrgba.B) * 0x101 * sa / 0xff + sa *= 0x101 + + rgba := dpix[x] + dr := uint32(rgba.R) + dg := uint32(rgba.G) + db := uint32(rgba.B) + da := uint32(rgba.A) + a := (m - sa) * 0x101 + dr = (dr*a + sr*m) / m + dg = (dg*a + sg*m) / m + db = (db*a + sb*m) / m + da = (da*a + sa*m) / m + dpix[x] = image.RGBAColor{uint8(dr >> 8), uint8(dg >> 8), uint8(db >> 8), uint8(da >> 8)} + } + } +} + func drawGlyphOver(dst *image.RGBA, r image.Rectangle, src *image.ColorImage, mask *image.Alpha, mp image.Point) { x0, x1 := r.Min.X, r.Max.X y0, y1 := r.Min.Y, r.Max.Y @@ -311,6 +357,73 @@ func drawCopySrc(dst *image.RGBA, r image.Rectangle, src *image.RGBA, sp image.P } } +func drawNRGBASrc(dst *image.RGBA, r image.Rectangle, src *image.NRGBA, sp image.Point) { + for y, sy := r.Min.Y, sp.Y; y != r.Max.Y; y, sy = y+1, sy+1 { + dpix := dst.Pix[y*dst.Stride : (y+1)*dst.Stride] + spix := src.Pix[sy*src.Stride : (sy+1)*src.Stride] + for x, sx := r.Min.X, sp.X; x != r.Max.X; x, sx = x+1, sx+1 { + // Convert from non-premultiplied color to pre-multiplied color. + // The order of operations here is to match the NRGBAColor.RGBA + // method in image/color.go. + snrgba := spix[sx] + sa := uint32(snrgba.A) + sr := uint32(snrgba.R) * 0x101 * sa / 0xff + sg := uint32(snrgba.G) * 0x101 * sa / 0xff + sb := uint32(snrgba.B) * 0x101 * sa / 0xff + sa *= 0x101 + + dpix[x] = image.RGBAColor{uint8(sr >> 8), uint8(sg >> 8), uint8(sb >> 8), uint8(sa >> 8)} + } + } +} + +func drawYCbCr(dst *image.RGBA, r image.Rectangle, src *ycbcr.YCbCr, sp image.Point) { + // A YCbCr image is always fully opaque, and so if the mask is implicitly nil + // (i.e. fully opaque) then the op is effectively always Src. + var ( + yy, cb, cr uint8 + rr, gg, bb uint8 + ) + switch src.SubsampleRatio { + case ycbcr.SubsampleRatio422: + for y, sy := r.Min.Y, sp.Y; y != r.Max.Y; y, sy = y+1, sy+1 { + dpix := dst.Pix[y*dst.Stride : (y+1)*dst.Stride] + for x, sx := r.Min.X, sp.X; x != r.Max.X; x, sx = x+1, sx+1 { + i := sx / 2 + yy = src.Y[sy*src.YStride+sx] + cb = src.Cb[sy*src.CStride+i] + cr = src.Cr[sy*src.CStride+i] + rr, gg, bb = ycbcr.YCbCrToRGB(yy, cb, cr) + dpix[x] = image.RGBAColor{rr, gg, bb, 255} + } + } + case ycbcr.SubsampleRatio420: + for y, sy := r.Min.Y, sp.Y; y != r.Max.Y; y, sy = y+1, sy+1 { + dpix := dst.Pix[y*dst.Stride : (y+1)*dst.Stride] + for x, sx := r.Min.X, sp.X; x != r.Max.X; x, sx = x+1, sx+1 { + i, j := sx/2, sy/2 + yy = src.Y[sy*src.YStride+sx] + cb = src.Cb[j*src.CStride+i] + cr = src.Cr[j*src.CStride+i] + rr, gg, bb = ycbcr.YCbCrToRGB(yy, cb, cr) + dpix[x] = image.RGBAColor{rr, gg, bb, 255} + } + } + default: + // Default to 4:4:4 subsampling. + for y, sy := r.Min.Y, sp.Y; y != r.Max.Y; y, sy = y+1, sy+1 { + dpix := dst.Pix[y*dst.Stride : (y+1)*dst.Stride] + for x, sx := r.Min.X, sp.X; x != r.Max.X; x, sx = x+1, sx+1 { + yy = src.Y[sy*src.YStride+sx] + cb = src.Cb[sy*src.CStride+sx] + cr = src.Cr[sy*src.CStride+sx] + rr, gg, bb = ycbcr.YCbCrToRGB(yy, cb, cr) + dpix[x] = image.RGBAColor{rr, gg, bb, 255} + } + } + } +} + func drawRGBA(dst *image.RGBA, r image.Rectangle, src image.Image, sp image.Point, mask image.Image, mp image.Point, op Op) { x0, x1, dx := r.Min.X, r.Max.X, 1 y0, y1, dy := r.Min.Y, r.Max.Y, 1 diff --git a/libgo/go/exp/draw/draw_test.go b/libgo/go/exp/draw/draw_test.go index 90c9e823d3e..873a2f24a40 100644 --- a/libgo/go/exp/draw/draw_test.go +++ b/libgo/go/exp/draw/draw_test.go @@ -6,6 +6,7 @@ package draw import ( "image" + "image/ycbcr" "testing" ) @@ -43,6 +44,34 @@ func vgradAlpha(alpha int) image.Image { return m } +func vgradGreenNRGBA(alpha int) image.Image { + m := image.NewNRGBA(16, 16) + for y := 0; y < 16; y++ { + for x := 0; x < 16; x++ { + m.Set(x, y, image.RGBAColor{0, uint8(y * 0x11), 0, uint8(alpha)}) + } + } + return m +} + +func vgradCr() image.Image { + m := &ycbcr.YCbCr{ + Y: make([]byte, 16*16), + Cb: make([]byte, 16*16), + Cr: make([]byte, 16*16), + YStride: 16, + CStride: 16, + SubsampleRatio: ycbcr.SubsampleRatio444, + Rect: image.Rect(0, 0, 16, 16), + } + for y := 0; y < 16; y++ { + for x := 0; x < 16; x++ { + m.Cr[y*m.CStride+x] = uint8(y * 0x11) + } + } + return m +} + func hgradRed(alpha int) Image { m := image.NewRGBA(16, 16) for y := 0; y < 16; y++ { @@ -95,6 +124,27 @@ var drawTests = []drawTest{ {"copyAlphaSrc", vgradGreen(90), fillAlpha(192), Src, image.RGBAColor{0, 36, 0, 68}}, {"copyNil", vgradGreen(90), nil, Over, image.RGBAColor{88, 48, 0, 255}}, {"copyNilSrc", vgradGreen(90), nil, Src, image.RGBAColor{0, 48, 0, 90}}, + // Uniform mask (100%, 75%, nil) and variable NRGBA source. + // At (x, y) == (8, 8): + // The destination pixel is {136, 0, 0, 255}. + // The source pixel is {0, 136, 0, 90} in NRGBA-space, which is {0, 48, 0, 90} in RGBA-space. + // The result pixel is different than in the "copy*" test cases because of rounding errors. + {"nrgba", vgradGreenNRGBA(90), fillAlpha(255), Over, image.RGBAColor{88, 46, 0, 255}}, + {"nrgbaSrc", vgradGreenNRGBA(90), fillAlpha(255), Src, image.RGBAColor{0, 46, 0, 90}}, + {"nrgbaAlpha", vgradGreenNRGBA(90), fillAlpha(192), Over, image.RGBAColor{100, 34, 0, 255}}, + {"nrgbaAlphaSrc", vgradGreenNRGBA(90), fillAlpha(192), Src, image.RGBAColor{0, 34, 0, 68}}, + {"nrgbaNil", vgradGreenNRGBA(90), nil, Over, image.RGBAColor{88, 46, 0, 255}}, + {"nrgbaNilSrc", vgradGreenNRGBA(90), nil, Src, image.RGBAColor{0, 46, 0, 90}}, + // Uniform mask (100%, 75%, nil) and variable YCbCr source. + // At (x, y) == (8, 8): + // The destination pixel is {136, 0, 0, 255}. + // The source pixel is {0, 0, 136} in YCbCr-space, which is {11, 38, 0, 255} in RGB-space. + {"ycbcr", vgradCr(), fillAlpha(255), Over, image.RGBAColor{11, 38, 0, 255}}, + {"ycbcrSrc", vgradCr(), fillAlpha(255), Src, image.RGBAColor{11, 38, 0, 255}}, + {"ycbcrAlpha", vgradCr(), fillAlpha(192), Over, image.RGBAColor{42, 28, 0, 255}}, + {"ycbcrAlphaSrc", vgradCr(), fillAlpha(192), Src, image.RGBAColor{8, 28, 0, 192}}, + {"ycbcrNil", vgradCr(), nil, Over, image.RGBAColor{11, 38, 0, 255}}, + {"ycbcrNilSrc", vgradCr(), nil, Src, image.RGBAColor{11, 38, 0, 255}}, // Variable mask and variable source. // At (x, y) == (8, 8): // The destination pixel is {136, 0, 0, 255}. diff --git a/libgo/go/exp/draw/x11/conn.go b/libgo/go/exp/draw/x11/conn.go index 53294af15c0..81c67267db6 100644 --- a/libgo/go/exp/draw/x11/conn.go +++ b/libgo/go/exp/draw/x11/conn.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements an X11 backend for the exp/draw package. +// Package x11 implements an X11 backend for the exp/draw package. // // The X protocol specification is at ftp://ftp.x.org/pub/X11R7.0/doc/PDF/proto.pdf. // A summary of the wire format can be found in XCB's xproto.xml. diff --git a/libgo/go/exp/eval/bridge.go b/libgo/go/exp/eval/bridge.go index 12835c4c028..f31d9ab9bd6 100644 --- a/libgo/go/exp/eval/bridge.go +++ b/libgo/go/exp/eval/bridge.go @@ -34,54 +34,49 @@ func TypeFromNative(t reflect.Type) Type { } var et Type - switch t := t.(type) { - case *reflect.BoolType: + switch t.Kind() { + case reflect.Bool: et = BoolType - case *reflect.FloatType: - switch t.Kind() { - case reflect.Float32: - et = Float32Type - case reflect.Float64: - et = Float64Type - } - case *reflect.IntType: - switch t.Kind() { - case reflect.Int16: - et = Int16Type - case reflect.Int32: - et = Int32Type - case reflect.Int64: - et = Int64Type - case reflect.Int8: - et = Int8Type - case reflect.Int: - et = IntType - } - case *reflect.UintType: - switch t.Kind() { - case reflect.Uint16: - et = Uint16Type - case reflect.Uint32: - et = Uint32Type - case reflect.Uint64: - et = Uint64Type - case reflect.Uint8: - et = Uint8Type - case reflect.Uint: - et = UintType - case reflect.Uintptr: - et = UintptrType - } - case *reflect.StringType: + + case reflect.Float32: + et = Float32Type + case reflect.Float64: + et = Float64Type + + case reflect.Int16: + et = Int16Type + case reflect.Int32: + et = Int32Type + case reflect.Int64: + et = Int64Type + case reflect.Int8: + et = Int8Type + case reflect.Int: + et = IntType + + case reflect.Uint16: + et = Uint16Type + case reflect.Uint32: + et = Uint32Type + case reflect.Uint64: + et = Uint64Type + case reflect.Uint8: + et = Uint8Type + case reflect.Uint: + et = UintType + case reflect.Uintptr: + et = UintptrType + + case reflect.String: et = StringType - case *reflect.ArrayType: + case reflect.Array: et = NewArrayType(int64(t.Len()), TypeFromNative(t.Elem())) - case *reflect.ChanType: + case reflect.Chan: log.Panicf("%T not implemented", t) - case *reflect.FuncType: + case reflect.Func: nin := t.NumIn() // Variadic functions have DotDotDotType at the end - variadic := t.DotDotDot() + variadic := t.IsVariadic() if variadic { nin-- } @@ -94,15 +89,15 @@ func TypeFromNative(t reflect.Type) Type { out[i] = TypeFromNative(t.Out(i)) } et = NewFuncType(in, variadic, out) - case *reflect.InterfaceType: + case reflect.Interface: log.Panicf("%T not implemented", t) - case *reflect.MapType: + case reflect.Map: log.Panicf("%T not implemented", t) - case *reflect.PtrType: + case reflect.Ptr: et = NewPtrType(TypeFromNative(t.Elem())) - case *reflect.SliceType: + case reflect.Slice: et = NewSliceType(TypeFromNative(t.Elem())) - case *reflect.StructType: + case reflect.Struct: n := t.NumField() fields := make([]StructField, n) for i := 0; i < n; i++ { @@ -113,7 +108,7 @@ func TypeFromNative(t reflect.Type) Type { fields[i].Anonymous = sf.Anonymous } et = NewStructType(fields) - case *reflect.UnsafePointerType: + case reflect.UnsafePointer: log.Panicf("%T not implemented", t) default: log.Panicf("unexpected reflect.Type: %T", t) @@ -133,7 +128,7 @@ func TypeFromNative(t reflect.Type) Type { } // TypeOfNative returns the interpreter Type of a regular Go value. -func TypeOfNative(v interface{}) Type { return TypeFromNative(reflect.Typeof(v)) } +func TypeOfNative(v interface{}) Type { return TypeFromNative(reflect.TypeOf(v)) } /* * Function bridging diff --git a/libgo/go/exp/eval/type.go b/libgo/go/exp/eval/type.go index 3f272ce4b6c..8a93d8a6c27 100644 --- a/libgo/go/exp/eval/type.go +++ b/libgo/go/exp/eval/type.go @@ -86,7 +86,7 @@ func hashTypeArray(key []Type) uintptr { if t == nil { continue } - addr := reflect.NewValue(t).(*reflect.PtrValue).Get() + addr := reflect.ValueOf(t).Pointer() hash ^= addr } return hash diff --git a/libgo/go/exp/eval/world.go b/libgo/go/exp/eval/world.go index 02d18bd7935..a5f6ac7e5e7 100644 --- a/libgo/go/exp/eval/world.go +++ b/libgo/go/exp/eval/world.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package is the beginning of an interpreter for Go. +// Package eval is the beginning of an interpreter for Go. // It can run simple Go programs but does not implement // interface values or packages. package eval diff --git a/libgo/go/exp/ogle/cmd.go b/libgo/go/exp/ogle/cmd.go index 813d3a875a6..a8db523ea18 100644 --- a/libgo/go/exp/ogle/cmd.go +++ b/libgo/go/exp/ogle/cmd.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Ogle is the beginning of a debugger for Go. +// Package ogle is the beginning of a debugger for Go. package ogle import ( diff --git a/libgo/go/exp/ogle/process.go b/libgo/go/exp/ogle/process.go index 58e830aa68b..7c803b3a27e 100644 --- a/libgo/go/exp/ogle/process.go +++ b/libgo/go/exp/ogle/process.go @@ -226,8 +226,8 @@ func (p *Process) bootstrap() { p.runtime.G = newManualType(eval.TypeOfNative(rt1G{}), p.Arch) // Get addresses of type.*runtime.XType for discrimination. - rtv := reflect.Indirect(reflect.NewValue(&p.runtime)).(*reflect.StructValue) - rtvt := rtv.Type().(*reflect.StructType) + rtv := reflect.Indirect(reflect.ValueOf(&p.runtime)) + rtvt := rtv.Type() for i := 0; i < rtv.NumField(); i++ { n := rtvt.Field(i).Name if n[0] != 'P' || n[1] < 'A' || n[1] > 'Z' { @@ -237,7 +237,7 @@ func (p *Process) bootstrap() { if sym == nil { continue } - rtv.Field(i).(*reflect.UintValue).Set(sym.Value) + rtv.Field(i).SetUint(sym.Value) } // Get runtime field indexes diff --git a/libgo/go/exp/ogle/rruntime.go b/libgo/go/exp/ogle/rruntime.go index 33f1935b89e..950418b5388 100644 --- a/libgo/go/exp/ogle/rruntime.go +++ b/libgo/go/exp/ogle/rruntime.go @@ -236,9 +236,9 @@ type runtimeValues struct { // indexes gathered from the remoteTypes recorded in a runtimeValues // structure. func fillRuntimeIndexes(runtime *runtimeValues, out *runtimeIndexes) { - outv := reflect.Indirect(reflect.NewValue(out)).(*reflect.StructValue) - outt := outv.Type().(*reflect.StructType) - runtimev := reflect.Indirect(reflect.NewValue(runtime)).(*reflect.StructValue) + outv := reflect.Indirect(reflect.ValueOf(out)) + outt := outv.Type() + runtimev := reflect.Indirect(reflect.ValueOf(runtime)) // out contains fields corresponding to each runtime type for i := 0; i < outt.NumField(); i++ { @@ -260,12 +260,12 @@ func fillRuntimeIndexes(runtime *runtimeValues, out *runtimeIndexes) { } // Fill this field of out - outStructv := outv.Field(i).(*reflect.StructValue) - outStructt := outStructv.Type().(*reflect.StructType) + outStructv := outv.Field(i) + outStructt := outStructv.Type() for j := 0; j < outStructt.NumField(); j++ { - f := outStructv.Field(j).(*reflect.IntValue) + f := outStructv.Field(j) name := outStructt.Field(j).Name - f.Set(int64(indexes[name])) + f.SetInt(int64(indexes[name])) } } } diff --git a/libgo/go/exp/wingui/zwinapi.go b/libgo/go/exp/wingui/zwinapi.go index 60aaac6cf16..6ae6330a1fa 100644 --- a/libgo/go/exp/wingui/zwinapi.go +++ b/libgo/go/exp/wingui/zwinapi.go @@ -1,4 +1,4 @@ -// mksyscall_windows.sh winapi.go +// mksyscall_windows.pl winapi.go // MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT package main diff --git a/libgo/go/expvar/expvar.go b/libgo/go/expvar/expvar.go index ed6cff78db4..7123d4b0f77 100644 --- a/libgo/go/expvar/expvar.go +++ b/libgo/go/expvar/expvar.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The expvar package provides a standardized interface to public variables, -// such as operation counters in servers. It exposes these variables via -// HTTP at /debug/vars in JSON format. +// Package expvar provides a standardized interface to public variables, such +// as operation counters in servers. It exposes these variables via HTTP at +// /debug/vars in JSON format. // // Operations to set or modify these public variables are atomic. // @@ -180,23 +180,14 @@ func (v *String) String() string { return strconv.Quote(v.s) } func (v *String) Set(value string) { v.s = value } -// IntFunc wraps a func() int64 to create a value that satisfies the Var interface. -// The function will be called each time the Var is evaluated. -type IntFunc func() int64 +// Func implements Var by calling the function +// and formatting the returned value using JSON. +type Func func() interface{} -func (v IntFunc) String() string { return strconv.Itoa64(v()) } - -// FloatFunc wraps a func() float64 to create a value that satisfies the Var interface. -// The function will be called each time the Var is evaluated. -type FloatFunc func() float64 - -func (v FloatFunc) String() string { return strconv.Ftoa64(v(), 'g', -1) } - -// StringFunc wraps a func() string to create value that satisfies the Var interface. -// The function will be called each time the Var is evaluated. -type StringFunc func() string - -func (f StringFunc) String() string { return strconv.Quote(f()) } +func (f Func) String() string { + v, _ := json.Marshal(f()) + return string(v) +} // All published variables. @@ -282,18 +273,16 @@ func expvarHandler(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "\n}\n") } -func memstats() string { - b, _ := json.MarshalIndent(&runtime.MemStats, "", "\t") - return string(b) +func cmdline() interface{} { + return os.Args } -func cmdline() string { - b, _ := json.Marshal(os.Args) - return string(b) +func memstats() interface{} { + return runtime.MemStats } func init() { http.Handle("/debug/vars", http.HandlerFunc(expvarHandler)) - Publish("cmdline", StringFunc(cmdline)) - Publish("memstats", StringFunc(memstats)) + Publish("cmdline", Func(cmdline)) + Publish("memstats", Func(memstats)) } diff --git a/libgo/go/expvar/expvar_test.go b/libgo/go/expvar/expvar_test.go index a8b1a96a93c..94926d9f8ce 100644 --- a/libgo/go/expvar/expvar_test.go +++ b/libgo/go/expvar/expvar_test.go @@ -114,41 +114,15 @@ func TestMapCounter(t *testing.T) { } } -func TestIntFunc(t *testing.T) { - x := int64(4) - ix := IntFunc(func() int64 { return x }) - if s := ix.String(); s != "4" { - t.Errorf("ix.String() = %v, want 4", s) +func TestFunc(t *testing.T) { + var x interface{} = []string{"a", "b"} + f := Func(func() interface{} { return x }) + if s, exp := f.String(), `["a","b"]`; s != exp { + t.Errorf(`f.String() = %q, want %q`, s, exp) } - x++ - if s := ix.String(); s != "5" { - t.Errorf("ix.String() = %v, want 5", s) - } -} - -func TestFloatFunc(t *testing.T) { - x := 8.5 - ix := FloatFunc(func() float64 { return x }) - if s := ix.String(); s != "8.5" { - t.Errorf("ix.String() = %v, want 3.14", s) - } - - x -= 1.25 - if s := ix.String(); s != "7.25" { - t.Errorf("ix.String() = %v, want 4.34", s) - } -} - -func TestStringFunc(t *testing.T) { - x := "hello" - sx := StringFunc(func() string { return x }) - if s, exp := sx.String(), `"hello"`; s != exp { - t.Errorf(`sx.String() = %q, want %q`, s, exp) - } - - x = "goodbye" - if s, exp := sx.String(), `"goodbye"`; s != exp { - t.Errorf(`sx.String() = %q, want %q`, s, exp) + x = 17 + if s, exp := f.String(), `17`; s != exp { + t.Errorf(`f.String() = %q, want %q`, s, exp) } } diff --git a/libgo/go/flag/flag.go b/libgo/go/flag/flag.go index 19a3104553f..9ed20e06b5a 100644 --- a/libgo/go/flag/flag.go +++ b/libgo/go/flag/flag.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* - The flag package implements command-line flag parsing. + Package flag implements command-line flag parsing. Usage: diff --git a/libgo/go/fmt/doc.go b/libgo/go/fmt/doc.go index 77ee62bb1dd..e4d4f184427 100644 --- a/libgo/go/fmt/doc.go +++ b/libgo/go/fmt/doc.go @@ -27,7 +27,7 @@ %o base 8 %x base 16, with lower-case letters for a-f %X base 16, with upper-case letters for A-F - %U Unicode format: U+1234; same as "U+%x" with 4 digits default + %U Unicode format: U+1234; same as "U+%0.4X" Floating-point and complex constituents: %b decimalless scientific notation with exponent a power of two, in the manner of strconv.Ftoa32, e.g. -123456p-78 diff --git a/libgo/go/fmt/fmt_test.go b/libgo/go/fmt/fmt_test.go index 3766c838a91..b3c0c5abed4 100644 --- a/libgo/go/fmt/fmt_test.go +++ b/libgo/go/fmt/fmt_test.go @@ -170,6 +170,7 @@ var fmttests = []struct { // unicode format {"%U", 0x1, "U+0001"}, + {"%U", uint(0x1), "U+0001"}, {"%.8U", 0x2, "U+00000002"}, {"%U", 0x1234, "U+1234"}, {"%U", 0x12345, "U+12345"}, diff --git a/libgo/go/fmt/print.go b/libgo/go/fmt/print.go index 4b68051188a..10e0fe7c85b 100644 --- a/libgo/go/fmt/print.go +++ b/libgo/go/fmt/print.go @@ -256,11 +256,11 @@ func Sprintln(a ...interface{}) string { // Get the i'th arg of the struct value. // If the arg itself is an interface, return a value for // the thing inside the interface, not the interface itself. -func getField(v *reflect.StructValue, i int) reflect.Value { +func getField(v reflect.Value, i int) reflect.Value { val := v.Field(i) - if i, ok := val.(*reflect.InterfaceValue); ok { + if i := val; i.Kind() == reflect.Interface { if inter := i.Interface(); inter != nil { - return reflect.NewValue(inter) + return reflect.ValueOf(inter) } } return val @@ -278,18 +278,13 @@ func parsenum(s string, start, end int) (num int, isnum bool, newi int) { return } -// Reflection values like reflect.FuncValue implement this method. We use it for %p. -type uintptrGetter interface { - Get() uintptr -} - func (p *pp) unknownType(v interface{}) { if v == nil { p.buf.Write(nilAngleBytes) return } p.buf.WriteByte('?') - p.buf.WriteString(reflect.Typeof(v).String()) + p.buf.WriteString(reflect.TypeOf(v).String()) p.buf.WriteByte('?') } @@ -301,7 +296,7 @@ func (p *pp) badVerb(verb int, val interface{}) { if val == nil { p.buf.Write(nilAngleBytes) } else { - p.buf.WriteString(reflect.Typeof(val).String()) + p.buf.WriteString(reflect.TypeOf(val).String()) p.add('=') p.printField(val, 'v', false, false, 0) } @@ -394,6 +389,8 @@ func (p *pp) fmtUint64(v uint64, verb int, goSyntax bool, value interface{}) { p.fmt.integer(int64(v), 16, unsigned, ldigits) case 'X': p.fmt.integer(int64(v), 16, unsigned, udigits) + case 'U': + p.fmtUnicode(int64(v)) default: p.badVerb(verb, value) } @@ -521,16 +518,16 @@ func (p *pp) fmtBytes(v []byte, verb int, goSyntax bool, depth int, value interf func (p *pp) fmtPointer(field interface{}, value reflect.Value, verb int, goSyntax bool) { var u uintptr - switch value.(type) { - case *reflect.ChanValue, *reflect.FuncValue, *reflect.MapValue, *reflect.PtrValue, *reflect.SliceValue, *reflect.UnsafePointerValue: - u = value.(uintptrGetter).Get() + switch value.Kind() { + case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer: + u = value.Pointer() default: p.badVerb(verb, field) return } if goSyntax { p.add('(') - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.add(')') p.add('(') if u == 0 { @@ -545,10 +542,10 @@ func (p *pp) fmtPointer(field interface{}, value reflect.Value, verb int, goSynt } var ( - intBits = reflect.Typeof(0).Bits() - floatBits = reflect.Typeof(0.0).Bits() - complexBits = reflect.Typeof(1i).Bits() - uintptrBits = reflect.Typeof(uintptr(0)).Bits() + intBits = reflect.TypeOf(0).Bits() + floatBits = reflect.TypeOf(0.0).Bits() + complexBits = reflect.TypeOf(1i).Bits() + uintptrBits = reflect.TypeOf(uintptr(0)).Bits() ) func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth int) (wasString bool) { @@ -565,10 +562,10 @@ func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth // %T (the value's type) and %p (its address) are special; we always do them first. switch verb { case 'T': - p.printField(reflect.Typeof(field).String(), 's', false, false, 0) + p.printField(reflect.TypeOf(field).String(), 's', false, false, 0) return false case 'p': - p.fmtPointer(field, reflect.NewValue(field), verb, goSyntax) + p.fmtPointer(field, reflect.ValueOf(field), verb, goSyntax) return false } // Is it a Formatter? @@ -656,38 +653,38 @@ func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth } // Need to use reflection - value := reflect.NewValue(field) + value := reflect.ValueOf(field) BigSwitch: - switch f := value.(type) { - case *reflect.BoolValue: - p.fmtBool(f.Get(), verb, field) - case *reflect.IntValue: - p.fmtInt64(f.Get(), verb, field) - case *reflect.UintValue: - p.fmtUint64(uint64(f.Get()), verb, goSyntax, field) - case *reflect.FloatValue: + switch f := value; f.Kind() { + case reflect.Bool: + p.fmtBool(f.Bool(), verb, field) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + p.fmtInt64(f.Int(), verb, field) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + p.fmtUint64(uint64(f.Uint()), verb, goSyntax, field) + case reflect.Float32, reflect.Float64: if f.Type().Size() == 4 { - p.fmtFloat32(float32(f.Get()), verb, field) + p.fmtFloat32(float32(f.Float()), verb, field) } else { - p.fmtFloat64(float64(f.Get()), verb, field) + p.fmtFloat64(float64(f.Float()), verb, field) } - case *reflect.ComplexValue: + case reflect.Complex64, reflect.Complex128: if f.Type().Size() == 8 { - p.fmtComplex64(complex64(f.Get()), verb, field) + p.fmtComplex64(complex64(f.Complex()), verb, field) } else { - p.fmtComplex128(complex128(f.Get()), verb, field) + p.fmtComplex128(complex128(f.Complex()), verb, field) } - case *reflect.StringValue: - p.fmtString(f.Get(), verb, goSyntax, field) - case *reflect.MapValue: + case reflect.String: + p.fmtString(f.String(), verb, goSyntax, field) + case reflect.Map: if goSyntax { p.buf.WriteString(f.Type().String()) p.buf.WriteByte('{') } else { p.buf.Write(mapBytes) } - keys := f.Keys() + keys := f.MapKeys() for i, key := range keys { if i > 0 { if goSyntax { @@ -698,20 +695,20 @@ BigSwitch: } p.printField(key.Interface(), verb, plus, goSyntax, depth+1) p.buf.WriteByte(':') - p.printField(f.Elem(key).Interface(), verb, plus, goSyntax, depth+1) + p.printField(f.MapIndex(key).Interface(), verb, plus, goSyntax, depth+1) } if goSyntax { p.buf.WriteByte('}') } else { p.buf.WriteByte(']') } - case *reflect.StructValue: + case reflect.Struct: if goSyntax { - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) } p.add('{') v := f - t := v.Type().(*reflect.StructType) + t := v.Type() for i := 0; i < v.NumField(); i++ { if i > 0 { if goSyntax { @@ -729,11 +726,11 @@ BigSwitch: p.printField(getField(v, i).Interface(), verb, plus, goSyntax, depth+1) } p.buf.WriteByte('}') - case *reflect.InterfaceValue: + case reflect.Interface: value := f.Elem() - if value == nil { + if !value.IsValid() { if goSyntax { - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.buf.Write(nilParenBytes) } else { p.buf.Write(nilAngleBytes) @@ -741,9 +738,9 @@ BigSwitch: } else { return p.printField(value.Interface(), verb, plus, goSyntax, depth+1) } - case reflect.ArrayOrSliceValue: + case reflect.Array, reflect.Slice: // Byte slices are special. - if f.Type().(reflect.ArrayOrSliceType).Elem().Kind() == reflect.Uint8 { + if f.Type().Elem().Kind() == reflect.Uint8 { // We know it's a slice of bytes, but we also know it does not have static type // []byte, or it would have been caught above. Therefore we cannot convert // it directly in the (slightly) obvious way: f.Interface().([]byte); it doesn't have @@ -753,13 +750,13 @@ BigSwitch: // if reflection could help a little more. bytes := make([]byte, f.Len()) for i := range bytes { - bytes[i] = byte(f.Elem(i).(*reflect.UintValue).Get()) + bytes[i] = byte(f.Index(i).Uint()) } p.fmtBytes(bytes, verb, goSyntax, depth, field) return verb == 's' } if goSyntax { - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.buf.WriteByte('{') } else { p.buf.WriteByte('[') @@ -772,24 +769,24 @@ BigSwitch: p.buf.WriteByte(' ') } } - p.printField(f.Elem(i).Interface(), verb, plus, goSyntax, depth+1) + p.printField(f.Index(i).Interface(), verb, plus, goSyntax, depth+1) } if goSyntax { p.buf.WriteByte('}') } else { p.buf.WriteByte(']') } - case *reflect.PtrValue: - v := f.Get() + case reflect.Ptr: + v := f.Pointer() // pointer to array or slice or struct? ok at top level // but not embedded (avoid loops) if v != 0 && depth == 0 { - switch a := f.Elem().(type) { - case reflect.ArrayOrSliceValue: + switch a := f.Elem(); a.Kind() { + case reflect.Array, reflect.Slice: p.buf.WriteByte('&') p.printField(a.Interface(), verb, plus, goSyntax, depth+1) break BigSwitch - case *reflect.StructValue: + case reflect.Struct: p.buf.WriteByte('&') p.printField(a.Interface(), verb, plus, goSyntax, depth+1) break BigSwitch @@ -797,7 +794,7 @@ BigSwitch: } if goSyntax { p.buf.WriteByte('(') - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.buf.WriteByte(')') p.buf.WriteByte('(') if v == 0 { @@ -813,7 +810,7 @@ BigSwitch: break } p.fmt0x64(uint64(v), true) - case *reflect.ChanValue, *reflect.FuncValue, *reflect.UnsafePointerValue: + case reflect.Chan, reflect.Func, reflect.UnsafePointer: p.fmtPointer(field, value, verb, goSyntax) default: p.unknownType(f) @@ -918,7 +915,7 @@ func (p *pp) doPrintf(format string, a []interface{}) { for ; fieldnum < len(a); fieldnum++ { field := a[fieldnum] if field != nil { - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.buf.WriteByte('=') } p.printField(field, 'v', false, false, 0) @@ -937,7 +934,7 @@ func (p *pp) doPrint(a []interface{}, addspace, addnewline bool) { // always add spaces if we're doing println field := a[fieldnum] if fieldnum > 0 { - isString := field != nil && reflect.Typeof(field).Kind() == reflect.String + isString := field != nil && reflect.TypeOf(field).Kind() == reflect.String if addspace || !isString && !prevString { p.buf.WriteByte(' ') } diff --git a/libgo/go/fmt/scan.go b/libgo/go/fmt/scan.go index 36271a8d466..42bc52c92bc 100644 --- a/libgo/go/fmt/scan.go +++ b/libgo/go/fmt/scan.go @@ -423,7 +423,7 @@ func (s *ss) token(skipSpace bool, f func(int) bool) []byte { // typeError indicates that the type of the operand did not match the format func (s *ss) typeError(field interface{}, expected string) { - s.errorString("expected field of type pointer to " + expected + "; found " + reflect.Typeof(field).String()) + s.errorString("expected field of type pointer to " + expected + "; found " + reflect.TypeOf(field).String()) } var complexError = os.ErrorString("syntax error scanning complex number") @@ -908,37 +908,37 @@ func (s *ss) scanOne(verb int, field interface{}) { // If we scanned to bytes, the slice would point at the buffer. *v = []byte(s.convertString(verb)) default: - val := reflect.NewValue(v) - ptr, ok := val.(*reflect.PtrValue) - if !ok { + val := reflect.ValueOf(v) + ptr := val + if ptr.Kind() != reflect.Ptr { s.errorString("Scan: type not a pointer: " + val.Type().String()) return } - switch v := ptr.Elem().(type) { - case *reflect.BoolValue: - v.Set(s.scanBool(verb)) - case *reflect.IntValue: - v.Set(s.scanInt(verb, v.Type().Bits())) - case *reflect.UintValue: - v.Set(s.scanUint(verb, v.Type().Bits())) - case *reflect.StringValue: - v.Set(s.convertString(verb)) - case *reflect.SliceValue: + switch v := ptr.Elem(); v.Kind() { + case reflect.Bool: + v.SetBool(s.scanBool(verb)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.SetInt(s.scanInt(verb, v.Type().Bits())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + v.SetUint(s.scanUint(verb, v.Type().Bits())) + case reflect.String: + v.SetString(s.convertString(verb)) + case reflect.Slice: // For now, can only handle (renamed) []byte. - typ := v.Type().(*reflect.SliceType) + typ := v.Type() if typ.Elem().Kind() != reflect.Uint8 { goto CantHandle } str := s.convertString(verb) v.Set(reflect.MakeSlice(typ, len(str), len(str))) for i := 0; i < len(str); i++ { - v.Elem(i).(*reflect.UintValue).Set(uint64(str[i])) + v.Index(i).SetUint(uint64(str[i])) } - case *reflect.FloatValue: + case reflect.Float32, reflect.Float64: s.skipSpace(false) - v.Set(s.convertFloat(s.floatToken(), v.Type().Bits())) - case *reflect.ComplexValue: - v.Set(s.scanComplex(verb, v.Type().Bits())) + v.SetFloat(s.convertFloat(s.floatToken(), v.Type().Bits())) + case reflect.Complex64, reflect.Complex128: + v.SetComplex(s.scanComplex(verb, v.Type().Bits())) default: CantHandle: s.errorString("Scan: can't handle type: " + val.Type().String()) diff --git a/libgo/go/fmt/scan_test.go b/libgo/go/fmt/scan_test.go index 8eb3e5bfbb0..da13eb2d112 100644 --- a/libgo/go/fmt/scan_test.go +++ b/libgo/go/fmt/scan_test.go @@ -370,8 +370,8 @@ func testScan(name string, t *testing.T, scan func(r io.Reader, a ...interface{} continue } // The incoming value may be a pointer - v := reflect.NewValue(test.in) - if p, ok := v.(*reflect.PtrValue); ok { + v := reflect.ValueOf(test.in) + if p := v; p.Kind() == reflect.Ptr { v = p.Elem() } val := v.Interface() @@ -409,8 +409,8 @@ func TestScanf(t *testing.T) { continue } // The incoming value may be a pointer - v := reflect.NewValue(test.in) - if p, ok := v.(*reflect.PtrValue); ok { + v := reflect.ValueOf(test.in) + if p := v; p.Kind() == reflect.Ptr { v = p.Elem() } val := v.Interface() @@ -486,7 +486,7 @@ func TestInf(t *testing.T) { } func testScanfMulti(name string, t *testing.T) { - sliceType := reflect.Typeof(make([]interface{}, 1)).(*reflect.SliceType) + sliceType := reflect.TypeOf(make([]interface{}, 1)) for _, test := range multiTests { var r io.Reader if name == "StringReader" { @@ -513,8 +513,8 @@ func testScanfMulti(name string, t *testing.T) { // Convert the slice of pointers into a slice of values resultVal := reflect.MakeSlice(sliceType, n, n) for i := 0; i < n; i++ { - v := reflect.NewValue(test.in[i]).(*reflect.PtrValue).Elem() - resultVal.Elem(i).(*reflect.InterfaceValue).Set(v) + v := reflect.ValueOf(test.in[i]).Elem() + resultVal.Index(i).Set(v) } result := resultVal.Interface() if !reflect.DeepEqual(result, test.out) { diff --git a/libgo/go/go/ast/ast.go b/libgo/go/go/ast/ast.go index ed3e2cdd9be..2fc1a60323d 100644 --- a/libgo/go/go/ast/ast.go +++ b/libgo/go/go/ast/ast.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The AST package declares the types used to represent -// syntax trees for Go packages. +// Package ast declares the types used to represent syntax trees for Go +// packages. // package ast diff --git a/libgo/go/go/ast/print.go b/libgo/go/go/ast/print.go index 82c334ece67..81e1da1d0aa 100644 --- a/libgo/go/go/ast/print.go +++ b/libgo/go/go/ast/print.go @@ -21,11 +21,12 @@ type FieldFilter func(name string, value reflect.Value) bool // NotNilFilter returns true for field values that are not nil; // it returns false otherwise. -func NotNilFilter(_ string, value reflect.Value) bool { - v, ok := value.(interface { - IsNil() bool - }) - return !ok || !v.IsNil() +func NotNilFilter(_ string, v reflect.Value) bool { + switch v.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return !v.IsNil() + } + return true } @@ -61,7 +62,7 @@ func Fprint(w io.Writer, fset *token.FileSet, x interface{}, f FieldFilter) (n i p.printf("nil\n") return } - p.print(reflect.NewValue(x)) + p.print(reflect.ValueOf(x)) p.printf("\n") return @@ -79,7 +80,7 @@ type printer struct { output io.Writer fset *token.FileSet filter FieldFilter - ptrmap map[interface{}]int // *reflect.PtrValue -> line number + ptrmap map[interface{}]int // *T -> line number written int // number of bytes written to output indent int // current indentation level last byte // the last byte processed by Write @@ -140,6 +141,11 @@ func (p *printer) printf(format string, args ...interface{}) { // Implementation note: Print is written for AST nodes but could be // used to print arbitrary data structures; such a version should // probably be in a different package. +// +// Note: This code detects (some) cycles created via pointers but +// not cycles that are created via slices or maps containing the +// same slice or map. Code for general data structures probably +// should catch those as well. func (p *printer) print(x reflect.Value) { if !NotNilFilter("", x) { @@ -147,57 +153,57 @@ func (p *printer) print(x reflect.Value) { return } - switch v := x.(type) { - case *reflect.InterfaceValue: - p.print(v.Elem()) + switch x.Kind() { + case reflect.Interface: + p.print(x.Elem()) - case *reflect.MapValue: - p.printf("%s (len = %d) {\n", x.Type().String(), v.Len()) + case reflect.Map: + p.printf("%s (len = %d) {\n", x.Type().String(), x.Len()) p.indent++ - for _, key := range v.Keys() { + for _, key := range x.MapKeys() { p.print(key) p.printf(": ") - p.print(v.Elem(key)) + p.print(x.MapIndex(key)) p.printf("\n") } p.indent-- p.printf("}") - case *reflect.PtrValue: + case reflect.Ptr: p.printf("*") // type-checked ASTs may contain cycles - use ptrmap // to keep track of objects that have been printed // already and print the respective line number instead - ptr := v.Interface() + ptr := x.Interface() if line, exists := p.ptrmap[ptr]; exists { p.printf("(obj @ %d)", line) } else { p.ptrmap[ptr] = p.line - p.print(v.Elem()) + p.print(x.Elem()) } - case *reflect.SliceValue: - if s, ok := v.Interface().([]byte); ok { + case reflect.Slice: + if s, ok := x.Interface().([]byte); ok { p.printf("%#q", s) return } - p.printf("%s (len = %d) {\n", x.Type().String(), v.Len()) + p.printf("%s (len = %d) {\n", x.Type().String(), x.Len()) p.indent++ - for i, n := 0, v.Len(); i < n; i++ { + for i, n := 0, x.Len(); i < n; i++ { p.printf("%d: ", i) - p.print(v.Elem(i)) + p.print(x.Index(i)) p.printf("\n") } p.indent-- p.printf("}") - case *reflect.StructValue: + case reflect.Struct: p.printf("%s {\n", x.Type().String()) p.indent++ - t := v.Type().(*reflect.StructType) + t := x.Type() for i, n := 0, t.NumField(); i < n; i++ { name := t.Field(i).Name - value := v.Field(i) + value := x.Field(i) if p.filter == nil || p.filter(name, value) { p.printf("%s: ", name) p.print(value) @@ -208,11 +214,20 @@ func (p *printer) print(x reflect.Value) { p.printf("}") default: - value := x.Interface() - // position values can be printed nicely if we have a file set - if pos, ok := value.(token.Pos); ok && p.fset != nil { - value = p.fset.Position(pos) + v := x.Interface() + switch v := v.(type) { + case string: + // print strings in quotes + p.printf("%q", v) + return + case token.Pos: + // position values can be printed nicely if we have a file set + if p.fset != nil { + p.printf("%s", p.fset.Position(v)) + return + } } - p.printf("%v", value) + // default + p.printf("%v", v) } } diff --git a/libgo/go/go/ast/print_test.go b/libgo/go/go/ast/print_test.go new file mode 100644 index 00000000000..0820dcfcef2 --- /dev/null +++ b/libgo/go/go/ast/print_test.go @@ -0,0 +1,80 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ast + +import ( + "bytes" + "strings" + "testing" +) + + +var tests = []struct { + x interface{} // x is printed as s + s string +}{ + // basic types + {nil, "0 nil"}, + {true, "0 true"}, + {42, "0 42"}, + {3.14, "0 3.14"}, + {1 + 2.718i, "0 (1+2.718i)"}, + {"foobar", "0 \"foobar\""}, + + // maps + {map[string]int{"a": 1, "b": 2}, + `0 map[string] int (len = 2) { + 1 . "a": 1 + 2 . "b": 2 + 3 }`}, + + // pointers + {new(int), "0 *0"}, + + // slices + {[]int{1, 2, 3}, + `0 []int (len = 3) { + 1 . 0: 1 + 2 . 1: 2 + 3 . 2: 3 + 4 }`}, + + // structs + {struct{ x, y int }{42, 991}, + `0 struct { x int; y int } { + 1 . x: 42 + 2 . y: 991 + 3 }`}, +} + + +// Split s into lines, trim whitespace from all lines, and return +// the concatenated non-empty lines. +func trim(s string) string { + lines := strings.Split(s, "\n", -1) + i := 0 + for _, line := range lines { + line = strings.TrimSpace(line) + if line != "" { + lines[i] = line + i++ + } + } + return strings.Join(lines[0:i], "\n") +} + + +func TestPrint(t *testing.T) { + var buf bytes.Buffer + for _, test := range tests { + buf.Reset() + if _, err := Fprint(&buf, nil, test.x, nil); err != nil { + t.Errorf("Fprint failed: %s", err) + } + if s, ts := trim(buf.String()), trim(test.s); s != ts { + t.Errorf("got:\n%s\nexpected:\n%s\n", s, ts) + } + } +} diff --git a/libgo/go/go/doc/doc.go b/libgo/go/go/doc/doc.go index e7a8d3f63bb..29d205d391c 100644 --- a/libgo/go/go/doc/doc.go +++ b/libgo/go/go/doc/doc.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The doc package extracts source code documentation from a Go AST. +// Package doc extracts source code documentation from a Go AST. package doc import ( diff --git a/libgo/go/go/parser/parser.go b/libgo/go/go/parser/parser.go index 84a0da6ae7b..afa9ae517b6 100644 --- a/libgo/go/go/parser/parser.go +++ b/libgo/go/go/parser/parser.go @@ -2,10 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// A parser for Go source files. Input may be provided in a variety of -// forms (see the various Parse* functions); the output is an abstract -// syntax tree (AST) representing the Go source. The parser is invoked -// through one of the Parse* functions. +// Package parser implements a parser for Go source files. Input may be +// provided in a variety of forms (see the various Parse* functions); the +// output is an abstract syntax tree (AST) representing the Go source. The +// parser is invoked through one of the Parse* functions. // package parser @@ -1780,10 +1780,6 @@ func (p *parser) parseCommClause() *ast.CommClause { rhs = lhs[0] lhs = nil // there is no lhs } - if x, isUnary := rhs.(*ast.UnaryExpr); !isUnary || x.Op != token.ARROW { - p.errorExpected(rhs.Pos(), "send or receive operation") - rhs = &ast.BadExpr{rhs.Pos(), rhs.End()} - } if lhs != nil { comm = &ast.AssignStmt{lhs, pos, tok, []ast.Expr{rhs}} } else { diff --git a/libgo/go/go/parser/parser_test.go b/libgo/go/go/parser/parser_test.go index 2f1ee6bfc09..5b52f51d4a5 100644 --- a/libgo/go/go/parser/parser_test.go +++ b/libgo/go/go/parser/parser_test.go @@ -51,6 +51,7 @@ var validPrograms = []interface{}{ `package p; type T []int; func f() { for _ = range []int{T{42}[0]} {} };`, `package p; var a = T{{1, 2}, {3, 4}}`, `package p; func f() { select { case <- c: case c <- d: case c <- <- d: case <-c <- d: } };`, + `package p; func f() { select { case x := (<-c): } };`, `package p; func f() { if ; true {} };`, `package p; func f() { switch ; {} };`, } diff --git a/libgo/go/go/printer/nodes.go b/libgo/go/go/printer/nodes.go index 0b3b6621e6c..86c32793062 100644 --- a/libgo/go/go/printer/nodes.go +++ b/libgo/go/go/printer/nodes.go @@ -1200,7 +1200,7 @@ func (p *printer) spec(spec ast.Spec, n int, doIndent bool, multiLine *bool) { p.setComment(s.Doc) if s.Name != nil { p.expr(s.Name, multiLine) - p.print(vtab) + p.print(blank) } p.expr(s.Path, multiLine) p.setComment(s.Comment) diff --git a/libgo/go/go/printer/performance_test.go b/libgo/go/go/printer/performance_test.go new file mode 100644 index 00000000000..31de0b7ad40 --- /dev/null +++ b/libgo/go/go/printer/performance_test.go @@ -0,0 +1,62 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements a simple printer performance benchmark: +// gotest -bench=BenchmarkPrint + +package printer + +import ( + "bytes" + "go/ast" + "go/parser" + "io" + "io/ioutil" + "log" + "testing" +) + + +var testfile *ast.File + + +func testprint(out io.Writer, file *ast.File) { + if _, err := (&Config{TabIndent | UseSpaces, 8}).Fprint(out, fset, file); err != nil { + log.Fatalf("print error: %s", err) + } +} + + +// cannot initialize in init because (printer) Fprint launches goroutines. +func initialize() { + const filename = "testdata/parser.go" + + src, err := ioutil.ReadFile(filename) + if err != nil { + log.Fatalf("%s", err) + } + + file, err := parser.ParseFile(fset, filename, src, parser.ParseComments) + if err != nil { + log.Fatalf("%s", err) + } + + var buf bytes.Buffer + testprint(&buf, file) + if !bytes.Equal(buf.Bytes(), src) { + log.Fatalf("print error: %s not idempotent", filename) + } + + testfile = file +} + + +func BenchmarkPrint(b *testing.B) { + if testfile == nil { + initialize() + } + for i := 0; i < b.N; i++ { + testprint(ioutil.Discard, testfile) + } +} diff --git a/libgo/go/go/printer/printer.go b/libgo/go/go/printer/printer.go index 697a83fa866..01ebf783c41 100644 --- a/libgo/go/go/printer/printer.go +++ b/libgo/go/go/printer/printer.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The printer package implements printing of AST nodes. +// Package printer implements printing of AST nodes. package printer import ( diff --git a/libgo/go/go/printer/testdata/declarations.golden b/libgo/go/go/printer/testdata/declarations.golden index 1c091b9295e..c1b255842c1 100644 --- a/libgo/go/go/printer/testdata/declarations.golden +++ b/libgo/go/go/printer/testdata/declarations.golden @@ -7,10 +7,10 @@ package imports import "io" import ( - _ "io" + _ "io" ) -import _ "io" +import _ "io" import ( "io" @@ -20,40 +20,40 @@ import ( import ( "io" - aLongRename "io" + aLongRename "io" - b "io" + b "io" ) import ( "unrenamed" - renamed "renameMe" - . "io" - _ "io" + renamed "renameMe" + . "io" + _ "io" "io" - . "os" + . "os" ) // no newlines between consecutive single imports, but // respect extra line breaks in the source (at most one empty line) -import _ "io" -import _ "io" -import _ "io" +import _ "io" +import _ "io" +import _ "io" -import _ "os" -import _ "os" -import _ "os" +import _ "os" +import _ "os" +import _ "os" -import _ "fmt" -import _ "fmt" -import _ "fmt" +import _ "fmt" +import _ "fmt" +import _ "fmt" import "foo" // a comment import "bar" // a comment import ( - _ "foo" + _ "foo" // a comment "bar" "foo" // a comment @@ -63,17 +63,17 @@ import ( // comments + renames import ( "unrenamed" // a comment - renamed "renameMe" - . "io" /* a comment */ - _ "io/ioutil" // a comment + renamed "renameMe" + . "io" /* a comment */ + _ "io/ioutil" // a comment "io" // testing alignment - . "os" + . "os" // a comment ) // a case that caused problems in the past (comment placement) import ( - . "fmt" + . "fmt" "io" "malloc" // for the malloc count test only "math" @@ -81,9 +81,38 @@ import ( "testing" ) +// more import examples +import ( + "xxx" + "much longer name" // comment + "short name" // comment +) + +import ( + _ "xxx" + "much longer name" // comment +) + +import ( + mymath "math" + "/foo/bar/long_package_path" // a comment +) + +import ( + "package_a" // comment + "package_b" + my_better_c "package_c" // comment + "package_d" // comment + my_e "package_e" // comment + + "package_a" // comment + "package_bb" + "package_ccc" // comment + "package_dddd" // comment +) // at least one empty line between declarations of different kind -import _ "io" +import _ "io" var _ int diff --git a/libgo/go/go/printer/testdata/declarations.input b/libgo/go/go/printer/testdata/declarations.input index c826462f9dc..c8b37e12ba4 100644 --- a/libgo/go/go/printer/testdata/declarations.input +++ b/libgo/go/go/printer/testdata/declarations.input @@ -81,6 +81,35 @@ import ( "testing" ) +// more import examples +import ( + "xxx" + "much longer name" // comment + "short name" // comment +) + +import ( + _ "xxx" + "much longer name" // comment +) + +import ( + mymath "math" + "/foo/bar/long_package_path" // a comment +) + +import ( + "package_a" // comment + "package_b" + my_better_c "package_c" // comment + "package_d" // comment + my_e "package_e" // comment + + "package_a" // comment + "package_bb" + "package_ccc" // comment + "package_dddd" // comment +) // at least one empty line between declarations of different kind import _ "io" diff --git a/libgo/go/go/printer/testdata/expressions.golden b/libgo/go/go/printer/testdata/expressions.golden index c1a7e970b45..3d0f144e10f 100644 --- a/libgo/go/go/printer/testdata/expressions.golden +++ b/libgo/go/go/printer/testdata/expressions.golden @@ -94,30 +94,49 @@ func _() { _ = under_bar - 1 _ = Open(dpath+"/file", O_WRONLY|O_CREAT, 0666) _ = int(c0&_Mask4)<<18 | int(c1&_Maskx)<<12 | int(c2&_Maskx)<<6 | int(c3&_Maskx) -} - -func _() { + // the parser does not restrict expressions that may appear as statements + true + 42 + "foo" + x + (x) a + b a + b + c - a + b*c a + (b * c) - (a + b) * c - a + (b * c * d) - a + (b*c + d) + a + (b / c) + 1 + a + a + 1 + s[a] + x << 1 + (s[0] << 1) & 0xf + "foo" + s + x == y + x < y || z > 42 +} + + +func _() { + _ = a + b + _ = a + b + c + _ = a + b*c + _ = a + (b * c) + _ = (a + b) * c + _ = a + (b * c * d) + _ = a + (b*c + d) - 1 << x - -1 << x - 1< 0 && i >= 0 + _ = x > 0 && i >= 0 x1, x0 := x>>w2, x&m2 z0 = t1<> (uint(w) - z)) x1 = x1<>(uint(w)-z) - buf[0 : len(buf)+1] - buf[0 : n+1] + _ = buf[0 : len(buf)+1] + _ = buf[0 : n+1] a, b = b, a a = b + c a = b*c + d - a*b + c - a - b - c - a - (b - c) - a - b*c - a - (b * c) - a * b / c - a / *b - x[a|^b] - x[a / *b] - a & ^b - a + +b - a - -b - x[a*-b] - x[a + +b] - x ^ y ^ z - b[a>>24] ^ b[(a>>16)&0xFF] ^ b[(a>>8)&0xFF] ^ b[a&0xFF] - len(longVariableName) * 2 - - token(matchType + xlength<>24] ^ b[(a>>16)&0xFF] ^ b[(a>>8)&0xFF] ^ b[a&0xFF] + _ = len(longVariableName) * 2 + + _ = token(matchType + xlength< 42 +} - 1< 0 && i >= 0 + _ = x > 0 && i >= 0 x1, x0 := x>>w2, x&m2 z0 = t1<>(uint(w)-z)) x1 = x1<>(uint(w)-z) - buf[0:len(buf)+1] - buf[0:n+1] + _ = buf[0:len(buf)+1] + _ = buf[0:n+1] a,b = b,a a = b+c a = b*c+d - a*b+c - a-b-c - a-(b-c) - a-b*c - a-(b*c) - a*b/c - a/ *b - x[a|^b] - x[a/ *b] - a& ^b - a+ +b - a- -b - x[a*-b] - x[a+ +b] - x^y^z - b[a>>24] ^ b[(a>>16)&0xFF] ^ b[(a>>8)&0xFF] ^ b[a&0xFF] - len(longVariableName)*2 - - token(matchType + xlength<>24] ^ b[(a>>16)&0xFF] ^ b[(a>>8)&0xFF] ^ b[a&0xFF] + _ = len(longVariableName)*2 + + _ = token(matchType + xlength< 42 +} + + +func _() { + _ = a + b + _ = a + b + c + _ = a + b*c + _ = a + (b * c) + _ = (a + b) * c + _ = a + (b * c * d) + _ = a + (b*c + d) - 1 << x - -1 << x - 1< 0 && i >= 0 + _ = x > 0 && i >= 0 x1, x0 := x>>w2, x&m2 z0 = t1<> (uint(w) - z)) x1 = x1<>(uint(w)-z) - buf[0 : len(buf)+1] - buf[0 : n+1] + _ = buf[0 : len(buf)+1] + _ = buf[0 : n+1] a, b = b, a a = b + c a = b*c + d - a*b + c - a - b - c - a - (b - c) - a - b*c - a - (b * c) - a * b / c - a / *b - x[a|^b] - x[a / *b] - a & ^b - a + +b - a - -b - x[a*-b] - x[a + +b] - x ^ y ^ z - b[a>>24] ^ b[(a>>16)&0xFF] ^ b[(a>>8)&0xFF] ^ b[a&0xFF] - len(longVariableName) * 2 - - token(matchType + xlength<>24] ^ b[(a>>16)&0xFF] ^ b[(a>>8)&0xFF] ^ b[a&0xFF] + _ = len(longVariableName) * 2 + + _ = token(matchType + xlength<= 0: in expression + + // Ordinary identifer scopes + pkgScope *ast.Scope // pkgScope.Outer == nil + topScope *ast.Scope // top-most scope; may be pkgScope + unresolved []*ast.Ident // unresolved identifiers + imports []*ast.ImportSpec // list of imports + + // Label scope + // (maintained by open/close LabelScope) + labelScope *ast.Scope // label scope for current function + targetStack [][]*ast.Ident // stack of unresolved labels +} + + +// scannerMode returns the scanner mode bits given the parser's mode bits. +func scannerMode(mode uint) uint { + var m uint = scanner.InsertSemis + if mode&ParseComments != 0 { + m |= scanner.ScanComments + } + return m +} + + +func (p *parser) init(fset *token.FileSet, filename string, src []byte, mode uint) { + p.file = fset.AddFile(filename, fset.Base(), len(src)) + p.scanner.Init(p.file, src, p, scannerMode(mode)) + + p.mode = mode + p.trace = mode&Trace != 0 // for convenience (p.trace is used frequently) + + p.next() + + // set up the pkgScope here (as opposed to in parseFile) because + // there are other parser entry points (ParseExpr, etc.) + p.openScope() + p.pkgScope = p.topScope + + // for the same reason, set up a label scope + p.openLabelScope() +} + + +// ---------------------------------------------------------------------------- +// Scoping support + +func (p *parser) openScope() { + p.topScope = ast.NewScope(p.topScope) +} + + +func (p *parser) closeScope() { + p.topScope = p.topScope.Outer +} + + +func (p *parser) openLabelScope() { + p.labelScope = ast.NewScope(p.labelScope) + p.targetStack = append(p.targetStack, nil) +} + + +func (p *parser) closeLabelScope() { + // resolve labels + n := len(p.targetStack) - 1 + scope := p.labelScope + for _, ident := range p.targetStack[n] { + ident.Obj = scope.Lookup(ident.Name) + if ident.Obj == nil && p.mode&DeclarationErrors != 0 { + p.error(ident.Pos(), fmt.Sprintf("label %s undefined", ident.Name)) + } + } + // pop label scope + p.targetStack = p.targetStack[0:n] + p.labelScope = p.labelScope.Outer +} + + +func (p *parser) declare(decl interface{}, scope *ast.Scope, kind ast.ObjKind, idents ...*ast.Ident) { + for _, ident := range idents { + assert(ident.Obj == nil, "identifier already declared or resolved") + if ident.Name != "_" { + obj := ast.NewObj(kind, ident.Name) + // remember the corresponding declaration for redeclaration + // errors and global variable resolution/typechecking phase + obj.Decl = decl + if alt := scope.Insert(obj); alt != nil && p.mode&DeclarationErrors != 0 { + prevDecl := "" + if pos := alt.Pos(); pos.IsValid() { + prevDecl = fmt.Sprintf("\n\tprevious declaration at %s", p.file.Position(pos)) + } + p.error(ident.Pos(), fmt.Sprintf("%s redeclared in this block%s", ident.Name, prevDecl)) + } + ident.Obj = obj + } + } +} + + +func (p *parser) shortVarDecl(idents []*ast.Ident) { + // Go spec: A short variable declaration may redeclare variables + // provided they were originally declared in the same block with + // the same type, and at least one of the non-blank variables is new. + n := 0 // number of new variables + for _, ident := range idents { + assert(ident.Obj == nil, "identifier already declared or resolved") + if ident.Name != "_" { + obj := ast.NewObj(ast.Var, ident.Name) + // short var declarations cannot have redeclaration errors + // and are not global => no need to remember the respective + // declaration + alt := p.topScope.Insert(obj) + if alt == nil { + n++ // new declaration + alt = obj + } + ident.Obj = alt + } + } + if n == 0 && p.mode&DeclarationErrors != 0 { + p.error(idents[0].Pos(), "no new variables on left side of :=") + } +} + + +// The unresolved object is a sentinel to mark identifiers that have been added +// to the list of unresolved identifiers. The sentinel is only used for verifying +// internal consistency. +var unresolved = new(ast.Object) + + +func (p *parser) resolve(x ast.Expr) { + // nothing to do if x is not an identifier or the blank identifier + ident, _ := x.(*ast.Ident) + if ident == nil { + return + } + assert(ident.Obj == nil, "identifier already declared or resolved") + if ident.Name == "_" { + return + } + // try to resolve the identifier + for s := p.topScope; s != nil; s = s.Outer { + if obj := s.Lookup(ident.Name); obj != nil { + ident.Obj = obj + return + } + } + // all local scopes are known, so any unresolved identifier + // must be found either in the file scope, package scope + // (perhaps in another file), or universe scope --- collect + // them so that they can be resolved later + ident.Obj = unresolved + p.unresolved = append(p.unresolved, ident) +} + + +// ---------------------------------------------------------------------------- +// Parsing support + +func (p *parser) printTrace(a ...interface{}) { + const dots = ". . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . " + + ". . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . " + const n = uint(len(dots)) + pos := p.file.Position(p.pos) + fmt.Printf("%5d:%3d: ", pos.Line, pos.Column) + i := 2 * p.indent + for ; i > n; i -= n { + fmt.Print(dots) + } + fmt.Print(dots[0:i]) + fmt.Println(a...) +} + + +func trace(p *parser, msg string) *parser { + p.printTrace(msg, "(") + p.indent++ + return p +} + + +// Usage pattern: defer un(trace(p, "...")); +func un(p *parser) { + p.indent-- + p.printTrace(")") +} + + +// Advance to the next token. +func (p *parser) next0() { + // Because of one-token look-ahead, print the previous token + // when tracing as it provides a more readable output. The + // very first token (!p.pos.IsValid()) is not initialized + // (it is token.ILLEGAL), so don't print it . + if p.trace && p.pos.IsValid() { + s := p.tok.String() + switch { + case p.tok.IsLiteral(): + p.printTrace(s, p.lit) + case p.tok.IsOperator(), p.tok.IsKeyword(): + p.printTrace("\"" + s + "\"") + default: + p.printTrace(s) + } + } + + p.pos, p.tok, p.lit = p.scanner.Scan() +} + +// Consume a comment and return it and the line on which it ends. +func (p *parser) consumeComment() (comment *ast.Comment, endline int) { + // /*-style comments may end on a different line than where they start. + // Scan the comment for '\n' chars and adjust endline accordingly. + endline = p.file.Line(p.pos) + if p.lit[1] == '*' { + // don't use range here - no need to decode Unicode code points + for i := 0; i < len(p.lit); i++ { + if p.lit[i] == '\n' { + endline++ + } + } + } + + comment = &ast.Comment{p.pos, p.lit} + p.next0() + + return +} + + +// Consume a group of adjacent comments, add it to the parser's +// comments list, and return it together with the line at which +// the last comment in the group ends. An empty line or non-comment +// token terminates a comment group. +// +func (p *parser) consumeCommentGroup() (comments *ast.CommentGroup, endline int) { + var list []*ast.Comment + endline = p.file.Line(p.pos) + for p.tok == token.COMMENT && endline+1 >= p.file.Line(p.pos) { + var comment *ast.Comment + comment, endline = p.consumeComment() + list = append(list, comment) + } + + // add comment group to the comments list + comments = &ast.CommentGroup{list} + p.comments = append(p.comments, comments) + + return +} + + +// Advance to the next non-comment token. In the process, collect +// any comment groups encountered, and remember the last lead and +// and line comments. +// +// A lead comment is a comment group that starts and ends in a +// line without any other tokens and that is followed by a non-comment +// token on the line immediately after the comment group. +// +// A line comment is a comment group that follows a non-comment +// token on the same line, and that has no tokens after it on the line +// where it ends. +// +// Lead and line comments may be considered documentation that is +// stored in the AST. +// +func (p *parser) next() { + p.leadComment = nil + p.lineComment = nil + line := p.file.Line(p.pos) // current line + p.next0() + + if p.tok == token.COMMENT { + var comment *ast.CommentGroup + var endline int + + if p.file.Line(p.pos) == line { + // The comment is on same line as the previous token; it + // cannot be a lead comment but may be a line comment. + comment, endline = p.consumeCommentGroup() + if p.file.Line(p.pos) != endline { + // The next token is on a different line, thus + // the last comment group is a line comment. + p.lineComment = comment + } + } + + // consume successor comments, if any + endline = -1 + for p.tok == token.COMMENT { + comment, endline = p.consumeCommentGroup() + } + + if endline+1 == p.file.Line(p.pos) { + // The next token is following on the line immediately after the + // comment group, thus the last comment group is a lead comment. + p.leadComment = comment + } + } +} + + +func (p *parser) error(pos token.Pos, msg string) { + p.Error(p.file.Position(pos), msg) +} + + +func (p *parser) errorExpected(pos token.Pos, msg string) { + msg = "expected " + msg + if pos == p.pos { + // the error happened at the current position; + // make the error message more specific + if p.tok == token.SEMICOLON && p.lit[0] == '\n' { + msg += ", found newline" + } else { + msg += ", found '" + p.tok.String() + "'" + if p.tok.IsLiteral() { + msg += " " + p.lit + } + } + } + p.error(pos, msg) +} + + +func (p *parser) expect(tok token.Token) token.Pos { + pos := p.pos + if p.tok != tok { + p.errorExpected(pos, "'"+tok.String()+"'") + } + p.next() // make progress + return pos +} + + +func (p *parser) expectSemi() { + if p.tok != token.RPAREN && p.tok != token.RBRACE { + p.expect(token.SEMICOLON) + } +} + + +func assert(cond bool, msg string) { + if !cond { + panic("go/parser internal error: " + msg) + } +} + + +// ---------------------------------------------------------------------------- +// Identifiers + +func (p *parser) parseIdent() *ast.Ident { + pos := p.pos + name := "_" + if p.tok == token.IDENT { + name = p.lit + p.next() + } else { + p.expect(token.IDENT) // use expect() error handling + } + return &ast.Ident{pos, name, nil} +} + + +func (p *parser) parseIdentList() (list []*ast.Ident) { + if p.trace { + defer un(trace(p, "IdentList")) + } + + list = append(list, p.parseIdent()) + for p.tok == token.COMMA { + p.next() + list = append(list, p.parseIdent()) + } + + return +} + + +// ---------------------------------------------------------------------------- +// Common productions + +// If lhs is set, result list elements which are identifiers are not resolved. +func (p *parser) parseExprList(lhs bool) (list []ast.Expr) { + if p.trace { + defer un(trace(p, "ExpressionList")) + } + + list = append(list, p.parseExpr(lhs)) + for p.tok == token.COMMA { + p.next() + list = append(list, p.parseExpr(lhs)) + } + + return +} + + +func (p *parser) parseLhsList() []ast.Expr { + list := p.parseExprList(true) + switch p.tok { + case token.DEFINE: + // lhs of a short variable declaration + p.shortVarDecl(p.makeIdentList(list)) + case token.COLON: + // lhs of a label declaration or a communication clause of a select + // statement (parseLhsList is not called when parsing the case clause + // of a switch statement): + // - labels are declared by the caller of parseLhsList + // - for communication clauses, if there is a stand-alone identifier + // followed by a colon, we have a syntax error; there is no need + // to resolve the identifier in that case + default: + // identifiers must be declared elsewhere + for _, x := range list { + p.resolve(x) + } + } + return list +} + + +func (p *parser) parseRhsList() []ast.Expr { + return p.parseExprList(false) +} + + +// ---------------------------------------------------------------------------- +// Types + +func (p *parser) parseType() ast.Expr { + if p.trace { + defer un(trace(p, "Type")) + } + + typ := p.tryType() + + if typ == nil { + pos := p.pos + p.errorExpected(pos, "type") + p.next() // make progress + return &ast.BadExpr{pos, p.pos} + } + + return typ +} + + +// If the result is an identifier, it is not resolved. +func (p *parser) parseTypeName() ast.Expr { + if p.trace { + defer un(trace(p, "TypeName")) + } + + ident := p.parseIdent() + // don't resolve ident yet - it may be a parameter or field name + + if p.tok == token.PERIOD { + // ident is a package name + p.next() + p.resolve(ident) + sel := p.parseIdent() + return &ast.SelectorExpr{ident, sel} + } + + return ident +} + + +func (p *parser) parseArrayType(ellipsisOk bool) ast.Expr { + if p.trace { + defer un(trace(p, "ArrayType")) + } + + lbrack := p.expect(token.LBRACK) + var len ast.Expr + if ellipsisOk && p.tok == token.ELLIPSIS { + len = &ast.Ellipsis{p.pos, nil} + p.next() + } else if p.tok != token.RBRACK { + len = p.parseRhs() + } + p.expect(token.RBRACK) + elt := p.parseType() + + return &ast.ArrayType{lbrack, len, elt} +} + + +func (p *parser) makeIdentList(list []ast.Expr) []*ast.Ident { + idents := make([]*ast.Ident, len(list)) + for i, x := range list { + ident, isIdent := x.(*ast.Ident) + if !isIdent { + pos := x.(ast.Expr).Pos() + p.errorExpected(pos, "identifier") + ident = &ast.Ident{pos, "_", nil} + } + idents[i] = ident + } + return idents +} + + +func (p *parser) parseFieldDecl(scope *ast.Scope) *ast.Field { + if p.trace { + defer un(trace(p, "FieldDecl")) + } + + doc := p.leadComment + + // fields + list, typ := p.parseVarList(false) + + // optional tag + var tag *ast.BasicLit + if p.tok == token.STRING { + tag = &ast.BasicLit{p.pos, p.tok, p.lit} + p.next() + } + + // analyze case + var idents []*ast.Ident + if typ != nil { + // IdentifierList Type + idents = p.makeIdentList(list) + } else { + // ["*"] TypeName (AnonymousField) + typ = list[0] // we always have at least one element + p.resolve(typ) + if n := len(list); n > 1 || !isTypeName(deref(typ)) { + pos := typ.Pos() + p.errorExpected(pos, "anonymous field") + typ = &ast.BadExpr{pos, list[n-1].End()} + } + } + + p.expectSemi() // call before accessing p.linecomment + + field := &ast.Field{doc, idents, typ, tag, p.lineComment} + p.declare(field, scope, ast.Var, idents...) + + return field +} + + +func (p *parser) parseStructType() *ast.StructType { + if p.trace { + defer un(trace(p, "StructType")) + } + + pos := p.expect(token.STRUCT) + lbrace := p.expect(token.LBRACE) + scope := ast.NewScope(nil) // struct scope + var list []*ast.Field + for p.tok == token.IDENT || p.tok == token.MUL || p.tok == token.LPAREN { + // a field declaration cannot start with a '(' but we accept + // it here for more robust parsing and better error messages + // (parseFieldDecl will check and complain if necessary) + list = append(list, p.parseFieldDecl(scope)) + } + rbrace := p.expect(token.RBRACE) + + // TODO(gri): store struct scope in AST + return &ast.StructType{pos, &ast.FieldList{lbrace, list, rbrace}, false} +} + + +func (p *parser) parsePointerType() *ast.StarExpr { + if p.trace { + defer un(trace(p, "PointerType")) + } + + star := p.expect(token.MUL) + base := p.parseType() + + return &ast.StarExpr{star, base} +} + + +func (p *parser) tryVarType(isParam bool) ast.Expr { + if isParam && p.tok == token.ELLIPSIS { + pos := p.pos + p.next() + typ := p.tryIdentOrType(isParam) // don't use parseType so we can provide better error message + if typ == nil { + p.error(pos, "'...' parameter is missing type") + typ = &ast.BadExpr{pos, p.pos} + } + if p.tok != token.RPAREN { + p.error(pos, "can use '...' with last parameter type only") + } + return &ast.Ellipsis{pos, typ} + } + return p.tryIdentOrType(false) +} + + +func (p *parser) parseVarType(isParam bool) ast.Expr { + typ := p.tryVarType(isParam) + if typ == nil { + pos := p.pos + p.errorExpected(pos, "type") + p.next() // make progress + typ = &ast.BadExpr{pos, p.pos} + } + return typ +} + + +func (p *parser) parseVarList(isParam bool) (list []ast.Expr, typ ast.Expr) { + if p.trace { + defer un(trace(p, "VarList")) + } + + // a list of identifiers looks like a list of type names + for { + // parseVarType accepts any type (including parenthesized ones) + // even though the syntax does not permit them here: we + // accept them all for more robust parsing and complain + // afterwards + list = append(list, p.parseVarType(isParam)) + if p.tok != token.COMMA { + break + } + p.next() + } + + // if we had a list of identifiers, it must be followed by a type + typ = p.tryVarType(isParam) + if typ != nil { + p.resolve(typ) + } + + return +} + + +func (p *parser) parseParameterList(scope *ast.Scope, ellipsisOk bool) (params []*ast.Field) { + if p.trace { + defer un(trace(p, "ParameterList")) + } + + list, typ := p.parseVarList(ellipsisOk) + if typ != nil { + // IdentifierList Type + idents := p.makeIdentList(list) + field := &ast.Field{nil, idents, typ, nil, nil} + params = append(params, field) + // Go spec: The scope of an identifier denoting a function + // parameter or result variable is the function body. + p.declare(field, scope, ast.Var, idents...) + if p.tok == token.COMMA { + p.next() + } + + for p.tok != token.RPAREN && p.tok != token.EOF { + idents := p.parseIdentList() + typ := p.parseVarType(ellipsisOk) + field := &ast.Field{nil, idents, typ, nil, nil} + params = append(params, field) + // Go spec: The scope of an identifier denoting a function + // parameter or result variable is the function body. + p.declare(field, scope, ast.Var, idents...) + if p.tok != token.COMMA { + break + } + p.next() + } + + } else { + // Type { "," Type } (anonymous parameters) + params = make([]*ast.Field, len(list)) + for i, x := range list { + p.resolve(x) + params[i] = &ast.Field{Type: x} + } + } + + return +} + + +func (p *parser) parseParameters(scope *ast.Scope, ellipsisOk bool) *ast.FieldList { + if p.trace { + defer un(trace(p, "Parameters")) + } + + var params []*ast.Field + lparen := p.expect(token.LPAREN) + if p.tok != token.RPAREN { + params = p.parseParameterList(scope, ellipsisOk) + } + rparen := p.expect(token.RPAREN) + + return &ast.FieldList{lparen, params, rparen} +} + + +func (p *parser) parseResult(scope *ast.Scope) *ast.FieldList { + if p.trace { + defer un(trace(p, "Result")) + } + + if p.tok == token.LPAREN { + return p.parseParameters(scope, false) + } + + typ := p.tryType() + if typ != nil { + list := make([]*ast.Field, 1) + list[0] = &ast.Field{Type: typ} + return &ast.FieldList{List: list} + } + + return nil +} + + +func (p *parser) parseSignature(scope *ast.Scope) (params, results *ast.FieldList) { + if p.trace { + defer un(trace(p, "Signature")) + } + + params = p.parseParameters(scope, true) + results = p.parseResult(scope) + + return +} + + +func (p *parser) parseFuncType() (*ast.FuncType, *ast.Scope) { + if p.trace { + defer un(trace(p, "FuncType")) + } + + pos := p.expect(token.FUNC) + scope := ast.NewScope(p.topScope) // function scope + params, results := p.parseSignature(scope) + + return &ast.FuncType{pos, params, results}, scope +} + + +func (p *parser) parseMethodSpec(scope *ast.Scope) *ast.Field { + if p.trace { + defer un(trace(p, "MethodSpec")) + } + + doc := p.leadComment + var idents []*ast.Ident + var typ ast.Expr + x := p.parseTypeName() + if ident, isIdent := x.(*ast.Ident); isIdent && p.tok == token.LPAREN { + // method + idents = []*ast.Ident{ident} + scope := ast.NewScope(nil) // method scope + params, results := p.parseSignature(scope) + typ = &ast.FuncType{token.NoPos, params, results} + } else { + // embedded interface + typ = x + } + p.expectSemi() // call before accessing p.linecomment + + spec := &ast.Field{doc, idents, typ, nil, p.lineComment} + p.declare(spec, scope, ast.Fun, idents...) + + return spec +} + + +func (p *parser) parseInterfaceType() *ast.InterfaceType { + if p.trace { + defer un(trace(p, "InterfaceType")) + } + + pos := p.expect(token.INTERFACE) + lbrace := p.expect(token.LBRACE) + scope := ast.NewScope(nil) // interface scope + var list []*ast.Field + for p.tok == token.IDENT { + list = append(list, p.parseMethodSpec(scope)) + } + rbrace := p.expect(token.RBRACE) + + // TODO(gri): store interface scope in AST + return &ast.InterfaceType{pos, &ast.FieldList{lbrace, list, rbrace}, false} +} + + +func (p *parser) parseMapType() *ast.MapType { + if p.trace { + defer un(trace(p, "MapType")) + } + + pos := p.expect(token.MAP) + p.expect(token.LBRACK) + key := p.parseType() + p.expect(token.RBRACK) + value := p.parseType() + + return &ast.MapType{pos, key, value} +} + + +func (p *parser) parseChanType() *ast.ChanType { + if p.trace { + defer un(trace(p, "ChanType")) + } + + pos := p.pos + dir := ast.SEND | ast.RECV + if p.tok == token.CHAN { + p.next() + if p.tok == token.ARROW { + p.next() + dir = ast.SEND + } + } else { + p.expect(token.ARROW) + p.expect(token.CHAN) + dir = ast.RECV + } + value := p.parseType() + + return &ast.ChanType{pos, dir, value} +} + + +// If the result is an identifier, it is not resolved. +func (p *parser) tryIdentOrType(ellipsisOk bool) ast.Expr { + switch p.tok { + case token.IDENT: + return p.parseTypeName() + case token.LBRACK: + return p.parseArrayType(ellipsisOk) + case token.STRUCT: + return p.parseStructType() + case token.MUL: + return p.parsePointerType() + case token.FUNC: + typ, _ := p.parseFuncType() + return typ + case token.INTERFACE: + return p.parseInterfaceType() + case token.MAP: + return p.parseMapType() + case token.CHAN, token.ARROW: + return p.parseChanType() + case token.LPAREN: + lparen := p.pos + p.next() + typ := p.parseType() + rparen := p.expect(token.RPAREN) + return &ast.ParenExpr{lparen, typ, rparen} + } + + // no type found + return nil +} + + +func (p *parser) tryType() ast.Expr { + typ := p.tryIdentOrType(false) + if typ != nil { + p.resolve(typ) + } + return typ +} + + +// ---------------------------------------------------------------------------- +// Blocks + +func (p *parser) parseStmtList() (list []ast.Stmt) { + if p.trace { + defer un(trace(p, "StatementList")) + } + + for p.tok != token.CASE && p.tok != token.DEFAULT && p.tok != token.RBRACE && p.tok != token.EOF { + list = append(list, p.parseStmt()) + } + + return +} + + +func (p *parser) parseBody(scope *ast.Scope) *ast.BlockStmt { + if p.trace { + defer un(trace(p, "Body")) + } + + lbrace := p.expect(token.LBRACE) + p.topScope = scope // open function scope + p.openLabelScope() + list := p.parseStmtList() + p.closeLabelScope() + p.closeScope() + rbrace := p.expect(token.RBRACE) + + return &ast.BlockStmt{lbrace, list, rbrace} +} + + +func (p *parser) parseBlockStmt() *ast.BlockStmt { + if p.trace { + defer un(trace(p, "BlockStmt")) + } + + lbrace := p.expect(token.LBRACE) + p.openScope() + list := p.parseStmtList() + p.closeScope() + rbrace := p.expect(token.RBRACE) + + return &ast.BlockStmt{lbrace, list, rbrace} +} + + +// ---------------------------------------------------------------------------- +// Expressions + +func (p *parser) parseFuncTypeOrLit() ast.Expr { + if p.trace { + defer un(trace(p, "FuncTypeOrLit")) + } + + typ, scope := p.parseFuncType() + if p.tok != token.LBRACE { + // function type only + return typ + } + + p.exprLev++ + body := p.parseBody(scope) + p.exprLev-- + + return &ast.FuncLit{typ, body} +} + + +// parseOperand may return an expression or a raw type (incl. array +// types of the form [...]T. Callers must verify the result. +// If lhs is set and the result is an identifier, it is not resolved. +// +func (p *parser) parseOperand(lhs bool) ast.Expr { + if p.trace { + defer un(trace(p, "Operand")) + } + + switch p.tok { + case token.IDENT: + x := p.parseIdent() + if !lhs { + p.resolve(x) + } + return x + + case token.INT, token.FLOAT, token.IMAG, token.CHAR, token.STRING: + x := &ast.BasicLit{p.pos, p.tok, p.lit} + p.next() + return x + + case token.LPAREN: + lparen := p.pos + p.next() + p.exprLev++ + x := p.parseRhs() + p.exprLev-- + rparen := p.expect(token.RPAREN) + return &ast.ParenExpr{lparen, x, rparen} + + case token.FUNC: + return p.parseFuncTypeOrLit() + + default: + if typ := p.tryIdentOrType(true); typ != nil { + // could be type for composite literal or conversion + _, isIdent := typ.(*ast.Ident) + assert(!isIdent, "type cannot be identifier") + return typ + } + } + + pos := p.pos + p.errorExpected(pos, "operand") + p.next() // make progress + return &ast.BadExpr{pos, p.pos} +} + + +func (p *parser) parseSelector(x ast.Expr) ast.Expr { + if p.trace { + defer un(trace(p, "Selector")) + } + + sel := p.parseIdent() + + return &ast.SelectorExpr{x, sel} +} + + +func (p *parser) parseTypeAssertion(x ast.Expr) ast.Expr { + if p.trace { + defer un(trace(p, "TypeAssertion")) + } + + p.expect(token.LPAREN) + var typ ast.Expr + if p.tok == token.TYPE { + // type switch: typ == nil + p.next() + } else { + typ = p.parseType() + } + p.expect(token.RPAREN) + + return &ast.TypeAssertExpr{x, typ} +} + + +func (p *parser) parseIndexOrSlice(x ast.Expr) ast.Expr { + if p.trace { + defer un(trace(p, "IndexOrSlice")) + } + + lbrack := p.expect(token.LBRACK) + p.exprLev++ + var low, high ast.Expr + isSlice := false + if p.tok != token.COLON { + low = p.parseRhs() + } + if p.tok == token.COLON { + isSlice = true + p.next() + if p.tok != token.RBRACK { + high = p.parseRhs() + } + } + p.exprLev-- + rbrack := p.expect(token.RBRACK) + + if isSlice { + return &ast.SliceExpr{x, lbrack, low, high, rbrack} + } + return &ast.IndexExpr{x, lbrack, low, rbrack} +} + + +func (p *parser) parseCallOrConversion(fun ast.Expr) *ast.CallExpr { + if p.trace { + defer un(trace(p, "CallOrConversion")) + } + + lparen := p.expect(token.LPAREN) + p.exprLev++ + var list []ast.Expr + var ellipsis token.Pos + for p.tok != token.RPAREN && p.tok != token.EOF && !ellipsis.IsValid() { + list = append(list, p.parseRhs()) + if p.tok == token.ELLIPSIS { + ellipsis = p.pos + p.next() + } + if p.tok != token.COMMA { + break + } + p.next() + } + p.exprLev-- + rparen := p.expect(token.RPAREN) + + return &ast.CallExpr{fun, lparen, list, ellipsis, rparen} +} + + +func (p *parser) parseElement(keyOk bool) ast.Expr { + if p.trace { + defer un(trace(p, "Element")) + } + + if p.tok == token.LBRACE { + return p.parseLiteralValue(nil) + } + + x := p.parseExpr(keyOk) // don't resolve if map key + if keyOk { + if p.tok == token.COLON { + colon := p.pos + p.next() + return &ast.KeyValueExpr{x, colon, p.parseElement(false)} + } + p.resolve(x) // not a map key + } + + return x +} + + +func (p *parser) parseElementList() (list []ast.Expr) { + if p.trace { + defer un(trace(p, "ElementList")) + } + + for p.tok != token.RBRACE && p.tok != token.EOF { + list = append(list, p.parseElement(true)) + if p.tok != token.COMMA { + break + } + p.next() + } + + return +} + + +func (p *parser) parseLiteralValue(typ ast.Expr) ast.Expr { + if p.trace { + defer un(trace(p, "LiteralValue")) + } + + lbrace := p.expect(token.LBRACE) + var elts []ast.Expr + p.exprLev++ + if p.tok != token.RBRACE { + elts = p.parseElementList() + } + p.exprLev-- + rbrace := p.expect(token.RBRACE) + return &ast.CompositeLit{typ, lbrace, elts, rbrace} +} + + +// checkExpr checks that x is an expression (and not a type). +func (p *parser) checkExpr(x ast.Expr) ast.Expr { + switch t := unparen(x).(type) { + case *ast.BadExpr: + case *ast.Ident: + case *ast.BasicLit: + case *ast.FuncLit: + case *ast.CompositeLit: + case *ast.ParenExpr: + panic("unreachable") + case *ast.SelectorExpr: + case *ast.IndexExpr: + case *ast.SliceExpr: + case *ast.TypeAssertExpr: + if t.Type == nil { + // the form X.(type) is only allowed in type switch expressions + p.errorExpected(x.Pos(), "expression") + x = &ast.BadExpr{x.Pos(), x.End()} + } + case *ast.CallExpr: + case *ast.StarExpr: + case *ast.UnaryExpr: + if t.Op == token.RANGE { + // the range operator is only allowed at the top of a for statement + p.errorExpected(x.Pos(), "expression") + x = &ast.BadExpr{x.Pos(), x.End()} + } + case *ast.BinaryExpr: + default: + // all other nodes are not proper expressions + p.errorExpected(x.Pos(), "expression") + x = &ast.BadExpr{x.Pos(), x.End()} + } + return x +} + + +// isTypeName returns true iff x is a (qualified) TypeName. +func isTypeName(x ast.Expr) bool { + switch t := x.(type) { + case *ast.BadExpr: + case *ast.Ident: + case *ast.SelectorExpr: + _, isIdent := t.X.(*ast.Ident) + return isIdent + default: + return false // all other nodes are not type names + } + return true +} + + +// isLiteralType returns true iff x is a legal composite literal type. +func isLiteralType(x ast.Expr) bool { + switch t := x.(type) { + case *ast.BadExpr: + case *ast.Ident: + case *ast.SelectorExpr: + _, isIdent := t.X.(*ast.Ident) + return isIdent + case *ast.ArrayType: + case *ast.StructType: + case *ast.MapType: + default: + return false // all other nodes are not legal composite literal types + } + return true +} + + +// If x is of the form *T, deref returns T, otherwise it returns x. +func deref(x ast.Expr) ast.Expr { + if p, isPtr := x.(*ast.StarExpr); isPtr { + x = p.X + } + return x +} + + +// If x is of the form (T), unparen returns unparen(T), otherwise it returns x. +func unparen(x ast.Expr) ast.Expr { + if p, isParen := x.(*ast.ParenExpr); isParen { + x = unparen(p.X) + } + return x +} + + +// checkExprOrType checks that x is an expression or a type +// (and not a raw type such as [...]T). +// +func (p *parser) checkExprOrType(x ast.Expr) ast.Expr { + switch t := unparen(x).(type) { + case *ast.ParenExpr: + panic("unreachable") + case *ast.UnaryExpr: + if t.Op == token.RANGE { + // the range operator is only allowed at the top of a for statement + p.errorExpected(x.Pos(), "expression") + x = &ast.BadExpr{x.Pos(), x.End()} + } + case *ast.ArrayType: + if len, isEllipsis := t.Len.(*ast.Ellipsis); isEllipsis { + p.error(len.Pos(), "expected array length, found '...'") + x = &ast.BadExpr{x.Pos(), x.End()} + } + } + + // all other nodes are expressions or types + return x +} + + +// If lhs is set and the result is an identifier, it is not resolved. +func (p *parser) parsePrimaryExpr(lhs bool) ast.Expr { + if p.trace { + defer un(trace(p, "PrimaryExpr")) + } + + x := p.parseOperand(lhs) +L: + for { + switch p.tok { + case token.PERIOD: + p.next() + if lhs { + p.resolve(x) + } + switch p.tok { + case token.IDENT: + x = p.parseSelector(p.checkExpr(x)) + case token.LPAREN: + x = p.parseTypeAssertion(p.checkExpr(x)) + default: + pos := p.pos + p.next() // make progress + p.errorExpected(pos, "selector or type assertion") + x = &ast.BadExpr{pos, p.pos} + } + case token.LBRACK: + if lhs { + p.resolve(x) + } + x = p.parseIndexOrSlice(p.checkExpr(x)) + case token.LPAREN: + if lhs { + p.resolve(x) + } + x = p.parseCallOrConversion(p.checkExprOrType(x)) + case token.LBRACE: + if isLiteralType(x) && (p.exprLev >= 0 || !isTypeName(x)) { + if lhs { + p.resolve(x) + } + x = p.parseLiteralValue(x) + } else { + break L + } + default: + break L + } + lhs = false // no need to try to resolve again + } + + return x +} + + +// If lhs is set and the result is an identifier, it is not resolved. +func (p *parser) parseUnaryExpr(lhs bool) ast.Expr { + if p.trace { + defer un(trace(p, "UnaryExpr")) + } + + switch p.tok { + case token.ADD, token.SUB, token.NOT, token.XOR, token.AND, token.RANGE: + pos, op := p.pos, p.tok + p.next() + x := p.parseUnaryExpr(false) + return &ast.UnaryExpr{pos, op, p.checkExpr(x)} + + case token.ARROW: + // channel type or receive expression + pos := p.pos + p.next() + if p.tok == token.CHAN { + p.next() + value := p.parseType() + return &ast.ChanType{pos, ast.RECV, value} + } + + x := p.parseUnaryExpr(false) + return &ast.UnaryExpr{pos, token.ARROW, p.checkExpr(x)} + + case token.MUL: + // pointer type or unary "*" expression + pos := p.pos + p.next() + x := p.parseUnaryExpr(false) + return &ast.StarExpr{pos, p.checkExprOrType(x)} + } + + return p.parsePrimaryExpr(lhs) +} + + +// If lhs is set and the result is an identifier, it is not resolved. +func (p *parser) parseBinaryExpr(lhs bool, prec1 int) ast.Expr { + if p.trace { + defer un(trace(p, "BinaryExpr")) + } + + x := p.parseUnaryExpr(lhs) + for prec := p.tok.Precedence(); prec >= prec1; prec-- { + for p.tok.Precedence() == prec { + pos, op := p.pos, p.tok + p.next() + if lhs { + p.resolve(x) + lhs = false + } + y := p.parseBinaryExpr(false, prec+1) + x = &ast.BinaryExpr{p.checkExpr(x), pos, op, p.checkExpr(y)} + } + } + + return x +} + + +// If lhs is set and the result is an identifier, it is not resolved. +// TODO(gri): parseExpr may return a type or even a raw type ([..]int) - +// should reject when a type/raw type is obviously not allowed +func (p *parser) parseExpr(lhs bool) ast.Expr { + if p.trace { + defer un(trace(p, "Expression")) + } + + return p.parseBinaryExpr(lhs, token.LowestPrec+1) +} + + +func (p *parser) parseRhs() ast.Expr { + return p.parseExpr(false) +} + + +// ---------------------------------------------------------------------------- +// Statements + +func (p *parser) parseSimpleStmt(labelOk bool) ast.Stmt { + if p.trace { + defer un(trace(p, "SimpleStmt")) + } + + x := p.parseLhsList() + + switch p.tok { + case + token.DEFINE, token.ASSIGN, token.ADD_ASSIGN, + token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN, + token.REM_ASSIGN, token.AND_ASSIGN, token.OR_ASSIGN, + token.XOR_ASSIGN, token.SHL_ASSIGN, token.SHR_ASSIGN, token.AND_NOT_ASSIGN: + // assignment statement + pos, tok := p.pos, p.tok + p.next() + y := p.parseRhsList() + return &ast.AssignStmt{x, pos, tok, y} + } + + if len(x) > 1 { + p.errorExpected(x[0].Pos(), "1 expression") + // continue with first expression + } + + switch p.tok { + case token.COLON: + // labeled statement + colon := p.pos + p.next() + if label, isIdent := x[0].(*ast.Ident); labelOk && isIdent { + // Go spec: The scope of a label is the body of the function + // in which it is declared and excludes the body of any nested + // function. + stmt := &ast.LabeledStmt{label, colon, p.parseStmt()} + p.declare(stmt, p.labelScope, ast.Lbl, label) + return stmt + } + p.error(x[0].Pos(), "illegal label declaration") + return &ast.BadStmt{x[0].Pos(), colon + 1} + + case token.ARROW: + // send statement + arrow := p.pos + p.next() // consume "<-" + y := p.parseRhs() + return &ast.SendStmt{x[0], arrow, y} + + case token.INC, token.DEC: + // increment or decrement + s := &ast.IncDecStmt{x[0], p.pos, p.tok} + p.next() // consume "++" or "--" + return s + } + + // expression + return &ast.ExprStmt{x[0]} +} + + +func (p *parser) parseCallExpr() *ast.CallExpr { + x := p.parseRhs() + if call, isCall := x.(*ast.CallExpr); isCall { + return call + } + p.errorExpected(x.Pos(), "function/method call") + return nil +} + + +func (p *parser) parseGoStmt() ast.Stmt { + if p.trace { + defer un(trace(p, "GoStmt")) + } + + pos := p.expect(token.GO) + call := p.parseCallExpr() + p.expectSemi() + if call == nil { + return &ast.BadStmt{pos, pos + 2} // len("go") + } + + return &ast.GoStmt{pos, call} +} + + +func (p *parser) parseDeferStmt() ast.Stmt { + if p.trace { + defer un(trace(p, "DeferStmt")) + } + + pos := p.expect(token.DEFER) + call := p.parseCallExpr() + p.expectSemi() + if call == nil { + return &ast.BadStmt{pos, pos + 5} // len("defer") + } + + return &ast.DeferStmt{pos, call} +} + + +func (p *parser) parseReturnStmt() *ast.ReturnStmt { + if p.trace { + defer un(trace(p, "ReturnStmt")) + } + + pos := p.pos + p.expect(token.RETURN) + var x []ast.Expr + if p.tok != token.SEMICOLON && p.tok != token.RBRACE { + x = p.parseRhsList() + } + p.expectSemi() + + return &ast.ReturnStmt{pos, x} +} + + +func (p *parser) parseBranchStmt(tok token.Token) *ast.BranchStmt { + if p.trace { + defer un(trace(p, "BranchStmt")) + } + + pos := p.expect(tok) + var label *ast.Ident + if tok != token.FALLTHROUGH && p.tok == token.IDENT { + label = p.parseIdent() + // add to list of unresolved targets + n := len(p.targetStack) - 1 + p.targetStack[n] = append(p.targetStack[n], label) + } + p.expectSemi() + + return &ast.BranchStmt{pos, tok, label} +} + + +func (p *parser) makeExpr(s ast.Stmt) ast.Expr { + if s == nil { + return nil + } + if es, isExpr := s.(*ast.ExprStmt); isExpr { + return p.checkExpr(es.X) + } + p.error(s.Pos(), "expected condition, found simple statement") + return &ast.BadExpr{s.Pos(), s.End()} +} + + +func (p *parser) parseIfStmt() *ast.IfStmt { + if p.trace { + defer un(trace(p, "IfStmt")) + } + + pos := p.expect(token.IF) + p.openScope() + defer p.closeScope() + + var s ast.Stmt + var x ast.Expr + { + prevLev := p.exprLev + p.exprLev = -1 + if p.tok == token.SEMICOLON { + p.next() + x = p.parseRhs() + } else { + s = p.parseSimpleStmt(false) + if p.tok == token.SEMICOLON { + p.next() + x = p.parseRhs() + } else { + x = p.makeExpr(s) + s = nil + } + } + p.exprLev = prevLev + } + + body := p.parseBlockStmt() + var else_ ast.Stmt + if p.tok == token.ELSE { + p.next() + else_ = p.parseStmt() + } else { + p.expectSemi() + } + + return &ast.IfStmt{pos, s, x, body, else_} +} + + +func (p *parser) parseTypeList() (list []ast.Expr) { + if p.trace { + defer un(trace(p, "TypeList")) + } + + list = append(list, p.parseType()) + for p.tok == token.COMMA { + p.next() + list = append(list, p.parseType()) + } + + return +} + + +func (p *parser) parseCaseClause(exprSwitch bool) *ast.CaseClause { + if p.trace { + defer un(trace(p, "CaseClause")) + } + + pos := p.pos + var list []ast.Expr + if p.tok == token.CASE { + p.next() + if exprSwitch { + list = p.parseRhsList() + } else { + list = p.parseTypeList() + } + } else { + p.expect(token.DEFAULT) + } + + colon := p.expect(token.COLON) + p.openScope() + body := p.parseStmtList() + p.closeScope() + + return &ast.CaseClause{pos, list, colon, body} +} + + +func isExprSwitch(s ast.Stmt) bool { + if s == nil { + return true + } + if e, ok := s.(*ast.ExprStmt); ok { + if a, ok := e.X.(*ast.TypeAssertExpr); ok { + return a.Type != nil // regular type assertion + } + return true + } + return false +} + + +func (p *parser) parseSwitchStmt() ast.Stmt { + if p.trace { + defer un(trace(p, "SwitchStmt")) + } + + pos := p.expect(token.SWITCH) + p.openScope() + defer p.closeScope() + + var s1, s2 ast.Stmt + if p.tok != token.LBRACE { + prevLev := p.exprLev + p.exprLev = -1 + if p.tok != token.SEMICOLON { + s2 = p.parseSimpleStmt(false) + } + if p.tok == token.SEMICOLON { + p.next() + s1 = s2 + s2 = nil + if p.tok != token.LBRACE { + s2 = p.parseSimpleStmt(false) + } + } + p.exprLev = prevLev + } + + exprSwitch := isExprSwitch(s2) + lbrace := p.expect(token.LBRACE) + var list []ast.Stmt + for p.tok == token.CASE || p.tok == token.DEFAULT { + list = append(list, p.parseCaseClause(exprSwitch)) + } + rbrace := p.expect(token.RBRACE) + p.expectSemi() + body := &ast.BlockStmt{lbrace, list, rbrace} + + if exprSwitch { + return &ast.SwitchStmt{pos, s1, p.makeExpr(s2), body} + } + // type switch + // TODO(gri): do all the checks! + return &ast.TypeSwitchStmt{pos, s1, s2, body} +} + + +func (p *parser) parseCommClause() *ast.CommClause { + if p.trace { + defer un(trace(p, "CommClause")) + } + + p.openScope() + pos := p.pos + var comm ast.Stmt + if p.tok == token.CASE { + p.next() + lhs := p.parseLhsList() + if p.tok == token.ARROW { + // SendStmt + if len(lhs) > 1 { + p.errorExpected(lhs[0].Pos(), "1 expression") + // continue with first expression + } + arrow := p.pos + p.next() + rhs := p.parseRhs() + comm = &ast.SendStmt{lhs[0], arrow, rhs} + } else { + // RecvStmt + pos := p.pos + tok := p.tok + var rhs ast.Expr + if tok == token.ASSIGN || tok == token.DEFINE { + // RecvStmt with assignment + if len(lhs) > 2 { + p.errorExpected(lhs[0].Pos(), "1 or 2 expressions") + // continue with first two expressions + lhs = lhs[0:2] + } + p.next() + rhs = p.parseRhs() + } else { + // rhs must be single receive operation + if len(lhs) > 1 { + p.errorExpected(lhs[0].Pos(), "1 expression") + // continue with first expression + } + rhs = lhs[0] + lhs = nil // there is no lhs + } + if x, isUnary := rhs.(*ast.UnaryExpr); !isUnary || x.Op != token.ARROW { + p.errorExpected(rhs.Pos(), "send or receive operation") + rhs = &ast.BadExpr{rhs.Pos(), rhs.End()} + } + if lhs != nil { + comm = &ast.AssignStmt{lhs, pos, tok, []ast.Expr{rhs}} + } else { + comm = &ast.ExprStmt{rhs} + } + } + } else { + p.expect(token.DEFAULT) + } + + colon := p.expect(token.COLON) + body := p.parseStmtList() + p.closeScope() + + return &ast.CommClause{pos, comm, colon, body} +} + + +func (p *parser) parseSelectStmt() *ast.SelectStmt { + if p.trace { + defer un(trace(p, "SelectStmt")) + } + + pos := p.expect(token.SELECT) + lbrace := p.expect(token.LBRACE) + var list []ast.Stmt + for p.tok == token.CASE || p.tok == token.DEFAULT { + list = append(list, p.parseCommClause()) + } + rbrace := p.expect(token.RBRACE) + p.expectSemi() + body := &ast.BlockStmt{lbrace, list, rbrace} + + return &ast.SelectStmt{pos, body} +} + + +func (p *parser) parseForStmt() ast.Stmt { + if p.trace { + defer un(trace(p, "ForStmt")) + } + + pos := p.expect(token.FOR) + p.openScope() + defer p.closeScope() + + var s1, s2, s3 ast.Stmt + if p.tok != token.LBRACE { + prevLev := p.exprLev + p.exprLev = -1 + if p.tok != token.SEMICOLON { + s2 = p.parseSimpleStmt(false) + } + if p.tok == token.SEMICOLON { + p.next() + s1 = s2 + s2 = nil + if p.tok != token.SEMICOLON { + s2 = p.parseSimpleStmt(false) + } + p.expectSemi() + if p.tok != token.LBRACE { + s3 = p.parseSimpleStmt(false) + } + } + p.exprLev = prevLev + } + + body := p.parseBlockStmt() + p.expectSemi() + + if as, isAssign := s2.(*ast.AssignStmt); isAssign { + // possibly a for statement with a range clause; check assignment operator + if as.Tok != token.ASSIGN && as.Tok != token.DEFINE { + p.errorExpected(as.TokPos, "'=' or ':='") + return &ast.BadStmt{pos, body.End()} + } + // check lhs + var key, value ast.Expr + switch len(as.Lhs) { + case 2: + key, value = as.Lhs[0], as.Lhs[1] + case 1: + key = as.Lhs[0] + default: + p.errorExpected(as.Lhs[0].Pos(), "1 or 2 expressions") + return &ast.BadStmt{pos, body.End()} + } + // check rhs + if len(as.Rhs) != 1 { + p.errorExpected(as.Rhs[0].Pos(), "1 expression") + return &ast.BadStmt{pos, body.End()} + } + if rhs, isUnary := as.Rhs[0].(*ast.UnaryExpr); isUnary && rhs.Op == token.RANGE { + // rhs is range expression + // (any short variable declaration was handled by parseSimpleStat above) + return &ast.RangeStmt{pos, key, value, as.TokPos, as.Tok, rhs.X, body} + } + p.errorExpected(s2.Pos(), "range clause") + return &ast.BadStmt{pos, body.End()} + } + + // regular for statement + return &ast.ForStmt{pos, s1, p.makeExpr(s2), s3, body} +} + + +func (p *parser) parseStmt() (s ast.Stmt) { + if p.trace { + defer un(trace(p, "Statement")) + } + + switch p.tok { + case token.CONST, token.TYPE, token.VAR: + s = &ast.DeclStmt{p.parseDecl()} + case + // tokens that may start a top-level expression + token.IDENT, token.INT, token.FLOAT, token.CHAR, token.STRING, token.FUNC, token.LPAREN, // operand + token.LBRACK, token.STRUCT, // composite type + token.MUL, token.AND, token.ARROW, token.ADD, token.SUB, token.XOR: // unary operators + s = p.parseSimpleStmt(true) + // because of the required look-ahead, labeled statements are + // parsed by parseSimpleStmt - don't expect a semicolon after + // them + if _, isLabeledStmt := s.(*ast.LabeledStmt); !isLabeledStmt { + p.expectSemi() + } + case token.GO: + s = p.parseGoStmt() + case token.DEFER: + s = p.parseDeferStmt() + case token.RETURN: + s = p.parseReturnStmt() + case token.BREAK, token.CONTINUE, token.GOTO, token.FALLTHROUGH: + s = p.parseBranchStmt(p.tok) + case token.LBRACE: + s = p.parseBlockStmt() + p.expectSemi() + case token.IF: + s = p.parseIfStmt() + case token.SWITCH: + s = p.parseSwitchStmt() + case token.SELECT: + s = p.parseSelectStmt() + case token.FOR: + s = p.parseForStmt() + case token.SEMICOLON: + s = &ast.EmptyStmt{p.pos} + p.next() + case token.RBRACE: + // a semicolon may be omitted before a closing "}" + s = &ast.EmptyStmt{p.pos} + default: + // no statement found + pos := p.pos + p.errorExpected(pos, "statement") + p.next() // make progress + s = &ast.BadStmt{pos, p.pos} + } + + return +} + + +// ---------------------------------------------------------------------------- +// Declarations + +type parseSpecFunction func(p *parser, doc *ast.CommentGroup, iota int) ast.Spec + + +func parseImportSpec(p *parser, doc *ast.CommentGroup, _ int) ast.Spec { + if p.trace { + defer un(trace(p, "ImportSpec")) + } + + var ident *ast.Ident + switch p.tok { + case token.PERIOD: + ident = &ast.Ident{p.pos, ".", nil} + p.next() + case token.IDENT: + ident = p.parseIdent() + } + + var path *ast.BasicLit + if p.tok == token.STRING { + path = &ast.BasicLit{p.pos, p.tok, p.lit} + p.next() + } else { + p.expect(token.STRING) // use expect() error handling + } + p.expectSemi() // call before accessing p.linecomment + + // collect imports + spec := &ast.ImportSpec{doc, ident, path, p.lineComment} + p.imports = append(p.imports, spec) + + return spec +} + + +func parseConstSpec(p *parser, doc *ast.CommentGroup, iota int) ast.Spec { + if p.trace { + defer un(trace(p, "ConstSpec")) + } + + idents := p.parseIdentList() + typ := p.tryType() + var values []ast.Expr + if typ != nil || p.tok == token.ASSIGN || iota == 0 { + p.expect(token.ASSIGN) + values = p.parseRhsList() + } + p.expectSemi() // call before accessing p.linecomment + + // Go spec: The scope of a constant or variable identifier declared inside + // a function begins at the end of the ConstSpec or VarSpec and ends at + // the end of the innermost containing block. + // (Global identifiers are resolved in a separate phase after parsing.) + spec := &ast.ValueSpec{doc, idents, typ, values, p.lineComment} + p.declare(spec, p.topScope, ast.Con, idents...) + + return spec +} + + +func parseTypeSpec(p *parser, doc *ast.CommentGroup, _ int) ast.Spec { + if p.trace { + defer un(trace(p, "TypeSpec")) + } + + ident := p.parseIdent() + + // Go spec: The scope of a type identifier declared inside a function begins + // at the identifier in the TypeSpec and ends at the end of the innermost + // containing block. + // (Global identifiers are resolved in a separate phase after parsing.) + spec := &ast.TypeSpec{doc, ident, nil, nil} + p.declare(spec, p.topScope, ast.Typ, ident) + + spec.Type = p.parseType() + p.expectSemi() // call before accessing p.linecomment + spec.Comment = p.lineComment + + return spec +} + + +func parseVarSpec(p *parser, doc *ast.CommentGroup, _ int) ast.Spec { + if p.trace { + defer un(trace(p, "VarSpec")) + } + + idents := p.parseIdentList() + typ := p.tryType() + var values []ast.Expr + if typ == nil || p.tok == token.ASSIGN { + p.expect(token.ASSIGN) + values = p.parseRhsList() + } + p.expectSemi() // call before accessing p.linecomment + + // Go spec: The scope of a constant or variable identifier declared inside + // a function begins at the end of the ConstSpec or VarSpec and ends at + // the end of the innermost containing block. + // (Global identifiers are resolved in a separate phase after parsing.) + spec := &ast.ValueSpec{doc, idents, typ, values, p.lineComment} + p.declare(spec, p.topScope, ast.Var, idents...) + + return spec +} + + +func (p *parser) parseGenDecl(keyword token.Token, f parseSpecFunction) *ast.GenDecl { + if p.trace { + defer un(trace(p, "GenDecl("+keyword.String()+")")) + } + + doc := p.leadComment + pos := p.expect(keyword) + var lparen, rparen token.Pos + var list []ast.Spec + if p.tok == token.LPAREN { + lparen = p.pos + p.next() + for iota := 0; p.tok != token.RPAREN && p.tok != token.EOF; iota++ { + list = append(list, f(p, p.leadComment, iota)) + } + rparen = p.expect(token.RPAREN) + p.expectSemi() + } else { + list = append(list, f(p, nil, 0)) + } + + return &ast.GenDecl{doc, pos, keyword, lparen, list, rparen} +} + + +func (p *parser) parseReceiver(scope *ast.Scope) *ast.FieldList { + if p.trace { + defer un(trace(p, "Receiver")) + } + + pos := p.pos + par := p.parseParameters(scope, false) + + // must have exactly one receiver + if par.NumFields() != 1 { + p.errorExpected(pos, "exactly one receiver") + // TODO determine a better range for BadExpr below + par.List = []*ast.Field{&ast.Field{Type: &ast.BadExpr{pos, pos}}} + return par + } + + // recv type must be of the form ["*"] identifier + recv := par.List[0] + base := deref(recv.Type) + if _, isIdent := base.(*ast.Ident); !isIdent { + p.errorExpected(base.Pos(), "(unqualified) identifier") + par.List = []*ast.Field{&ast.Field{Type: &ast.BadExpr{recv.Pos(), recv.End()}}} + } + + return par +} + + +func (p *parser) parseFuncDecl() *ast.FuncDecl { + if p.trace { + defer un(trace(p, "FunctionDecl")) + } + + doc := p.leadComment + pos := p.expect(token.FUNC) + scope := ast.NewScope(p.topScope) // function scope + + var recv *ast.FieldList + if p.tok == token.LPAREN { + recv = p.parseReceiver(scope) + } + + ident := p.parseIdent() + + params, results := p.parseSignature(scope) + + var body *ast.BlockStmt + if p.tok == token.LBRACE { + body = p.parseBody(scope) + } + p.expectSemi() + + decl := &ast.FuncDecl{doc, recv, ident, &ast.FuncType{pos, params, results}, body} + if recv == nil { + // Go spec: The scope of an identifier denoting a constant, type, + // variable, or function (but not method) declared at top level + // (outside any function) is the package block. + // + // init() functions cannot be referred to and there may + // be more than one - don't put them in the pkgScope + if ident.Name != "init" { + p.declare(decl, p.pkgScope, ast.Fun, ident) + } + } + + return decl +} + + +func (p *parser) parseDecl() ast.Decl { + if p.trace { + defer un(trace(p, "Declaration")) + } + + var f parseSpecFunction + switch p.tok { + case token.CONST: + f = parseConstSpec + + case token.TYPE: + f = parseTypeSpec + + case token.VAR: + f = parseVarSpec + + case token.FUNC: + return p.parseFuncDecl() + + default: + pos := p.pos + p.errorExpected(pos, "declaration") + p.next() // make progress + decl := &ast.BadDecl{pos, p.pos} + return decl + } + + return p.parseGenDecl(p.tok, f) +} + + +func (p *parser) parseDeclList() (list []ast.Decl) { + if p.trace { + defer un(trace(p, "DeclList")) + } + + for p.tok != token.EOF { + list = append(list, p.parseDecl()) + } + + return +} + + +// ---------------------------------------------------------------------------- +// Source files + +func (p *parser) parseFile() *ast.File { + if p.trace { + defer un(trace(p, "File")) + } + + // package clause + doc := p.leadComment + pos := p.expect(token.PACKAGE) + // Go spec: The package clause is not a declaration; + // the package name does not appear in any scope. + ident := p.parseIdent() + if ident.Name == "_" { + p.error(p.pos, "invalid package name _") + } + p.expectSemi() + + var decls []ast.Decl + + // Don't bother parsing the rest if we had errors already. + // Likely not a Go source file at all. + + if p.ErrorCount() == 0 && p.mode&PackageClauseOnly == 0 { + // import decls + for p.tok == token.IMPORT { + decls = append(decls, p.parseGenDecl(token.IMPORT, parseImportSpec)) + } + + if p.mode&ImportsOnly == 0 { + // rest of package body + for p.tok != token.EOF { + decls = append(decls, p.parseDecl()) + } + } + } + + assert(p.topScope == p.pkgScope, "imbalanced scopes") + + // resolve global identifiers within the same file + i := 0 + for _, ident := range p.unresolved { + // i <= index for current ident + assert(ident.Obj == unresolved, "object already resolved") + ident.Obj = p.pkgScope.Lookup(ident.Name) // also removes unresolved sentinel + if ident.Obj == nil { + p.unresolved[i] = ident + i++ + } + } + + // TODO(gri): store p.imports in AST + return &ast.File{doc, pos, ident, decls, p.pkgScope, p.imports, p.unresolved[0:i], p.comments} +} diff --git a/libgo/go/go/scanner/scanner.go b/libgo/go/go/scanner/scanner.go index 2f949ad2568..07b7454c87d 100644 --- a/libgo/go/go/scanner/scanner.go +++ b/libgo/go/go/scanner/scanner.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// A scanner for Go source text. Takes a []byte as source which can -// then be tokenized through repeated calls to the Scan function. -// Typical use: +// Package scanner implements a scanner for Go source text. Takes a []byte as +// source which can then be tokenized through repeated calls to the Scan +// function. Typical use: // // var s Scanner // fset := token.NewFileSet() // position information is relative to fset diff --git a/libgo/go/go/token/position.go b/libgo/go/go/token/position.go index 809e53f0aa2..8c35eeb52f7 100644 --- a/libgo/go/go/token/position.go +++ b/libgo/go/go/token/position.go @@ -94,10 +94,14 @@ func searchFiles(a []*File, x int) int { func (s *FileSet) file(p Pos) *File { + if f := s.last; f != nil && f.base <= int(p) && int(p) <= f.base+f.size { + return f + } if i := searchFiles(s.files, int(p)); i >= 0 { f := s.files[i] // f.base <= int(p) by definition of searchFiles if int(p) <= f.base+f.size { + s.last = f return f } } @@ -316,8 +320,26 @@ func (f *File) Position(p Pos) (pos Position) { } -func searchUints(a []int, x int) int { - return sort.Search(len(a), func(i int) bool { return a[i] > x }) - 1 +func searchInts(a []int, x int) int { + // This function body is a manually inlined version of: + // + // return sort.Search(len(a), func(i int) bool { return a[i] > x }) - 1 + // + // With better compiler optimizations, this may not be needed in the + // future, but at the moment this change improves the go/printer + // benchmark performance by ~30%. This has a direct impact on the + // speed of gofmt and thus seems worthwhile (2011-04-29). + i, j := 0, len(a) + for i < j { + h := i + (j-i)/2 // avoid overflow when computing h + // i ≤ h < j + if a[h] <= x { + i = h + 1 + } else { + j = h + } + } + return i - 1 } @@ -329,14 +351,17 @@ func searchLineInfos(a []lineInfo, x int) int { // info returns the file name, line, and column number for a file offset. func (f *File) info(offset int) (filename string, line, column int) { filename = f.name - if i := searchUints(f.lines, offset); i >= 0 { + if i := searchInts(f.lines, offset); i >= 0 { line, column = i+1, offset-f.lines[i]+1 } - if i := searchLineInfos(f.infos, offset); i >= 0 { - alt := &f.infos[i] - filename = alt.filename - if i := searchUints(f.lines, alt.offset); i >= 0 { - line += alt.line - i - 1 + if len(f.infos) > 0 { + // almost no files have extra line infos + if i := searchLineInfos(f.infos, offset); i >= 0 { + alt := &f.infos[i] + filename = alt.filename + if i := searchInts(f.lines, alt.offset); i >= 0 { + line += alt.line - i - 1 + } } } return @@ -348,10 +373,10 @@ func (f *File) info(offset int) (filename string, line, column int) { // may invoke them concurrently. // type FileSet struct { - mutex sync.RWMutex // protects the file set - base int // base offset for the next file - files []*File // list of files in the order added to the set - index map[*File]int // file -> files index for quick lookup + mutex sync.RWMutex // protects the file set + base int // base offset for the next file + files []*File // list of files in the order added to the set + last *File // cache of last file looked up } @@ -359,7 +384,6 @@ type FileSet struct { func NewFileSet() *FileSet { s := new(FileSet) s.base = 1 // 0 == NoPos - s.index = make(map[*File]int) return s } @@ -405,8 +429,8 @@ func (s *FileSet) AddFile(filename string, base, size int) *File { } // add the file to the file set s.base = base - s.index[f] = len(s.files) s.files = append(s.files, f) + s.last = f return f } diff --git a/libgo/go/go/token/token.go b/libgo/go/go/token/token.go index a5f21df168e..c2ec80ae140 100644 --- a/libgo/go/go/token/token.go +++ b/libgo/go/go/token/token.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package defines constants representing the lexical -// tokens of the Go programming language and basic operations -// on tokens (printing, predicates). +// Package token defines constants representing the lexical tokens of the Go +// programming language and basic operations on tokens (printing, predicates). // package token diff --git a/libgo/go/go/typechecker/typechecker.go b/libgo/go/go/typechecker/typechecker.go index b5e695d973a..b151f5834da 100644 --- a/libgo/go/go/typechecker/typechecker.go +++ b/libgo/go/go/typechecker/typechecker.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// INCOMPLETE PACKAGE. +// DEPRECATED PACKAGE - SEE go/types INSTEAD. // This package implements typechecking of a Go AST. // The result of the typecheck is an augmented AST // with object and type information for each identifier. diff --git a/libgo/go/go/types/const.go b/libgo/go/go/types/const.go new file mode 100644 index 00000000000..6fdc22f6b34 --- /dev/null +++ b/libgo/go/go/types/const.go @@ -0,0 +1,347 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements operations on ideal constants. + +package types + +import ( + "big" + "go/token" + "strconv" +) + + +// TODO(gri) Consider changing the API so Const is an interface +// and operations on consts don't have to type switch. + +// A Const implements an ideal constant Value. +// The zero value z for a Const is not a valid constant value. +type Const struct { + // representation of constant values: + // ideal bool -> bool + // ideal int -> *big.Int + // ideal float -> *big.Rat + // ideal complex -> cmplx + // ideal string -> string + val interface{} +} + + +// Representation of complex values. +type cmplx struct { + re, im *big.Rat +} + + +func assert(cond bool) { + if !cond { + panic("go/types internal error: assertion failed") + } +} + + +// MakeConst makes an ideal constant from a literal +// token and the corresponding literal string. +func MakeConst(tok token.Token, lit string) Const { + switch tok { + case token.INT: + var x big.Int + _, ok := x.SetString(lit, 0) + assert(ok) + return Const{&x} + case token.FLOAT: + var y big.Rat + _, ok := y.SetString(lit) + assert(ok) + return Const{&y} + case token.IMAG: + assert(lit[len(lit)-1] == 'i') + var im big.Rat + _, ok := im.SetString(lit[0 : len(lit)-1]) + assert(ok) + return Const{cmplx{big.NewRat(0, 1), &im}} + case token.CHAR: + assert(lit[0] == '\'' && lit[len(lit)-1] == '\'') + code, _, _, err := strconv.UnquoteChar(lit[1:len(lit)-1], '\'') + assert(err == nil) + return Const{big.NewInt(int64(code))} + case token.STRING: + s, err := strconv.Unquote(lit) + assert(err == nil) + return Const{s} + } + panic("unreachable") +} + + +// MakeZero returns the zero constant for the given type. +func MakeZero(typ *Type) Const { + // TODO(gri) fix this + return Const{0} +} + + +// Match attempts to match the internal constant representations of x and y. +// If the attempt is successful, the result is the values of x and y, +// if necessary converted to have the same internal representation; otherwise +// the results are invalid. +func (x Const) Match(y Const) (u, v Const) { + switch a := x.val.(type) { + case bool: + if _, ok := y.val.(bool); ok { + u, v = x, y + } + case *big.Int: + switch y.val.(type) { + case *big.Int: + u, v = x, y + case *big.Rat: + var z big.Rat + z.SetInt(a) + u, v = Const{&z}, y + case cmplx: + var z big.Rat + z.SetInt(a) + u, v = Const{cmplx{&z, big.NewRat(0, 1)}}, y + } + case *big.Rat: + switch y.val.(type) { + case *big.Int: + v, u = y.Match(x) + case *big.Rat: + u, v = x, y + case cmplx: + u, v = Const{cmplx{a, big.NewRat(0, 0)}}, y + } + case cmplx: + switch y.val.(type) { + case *big.Int, *big.Rat: + v, u = y.Match(x) + case cmplx: + u, v = x, y + } + case string: + if _, ok := y.val.(string); ok { + u, v = x, y + } + default: + panic("unreachable") + } + return +} + + +// Convert attempts to convert the constant x to a given type. +// If the attempt is successful, the result is the new constant; +// otherwise the result is invalid. +func (x Const) Convert(typ *Type) Const { + // TODO(gri) implement this + switch x := x.val.(type) { + case bool: + case *big.Int: + case *big.Rat: + case cmplx: + case string: + } + return x +} + + +func (x Const) String() string { + switch x := x.val.(type) { + case bool: + if x { + return "true" + } + return "false" + case *big.Int: + return x.String() + case *big.Rat: + return x.FloatString(10) // 10 digits of precision after decimal point seems fine + case cmplx: + // TODO(gri) don't print 0 components + return x.re.FloatString(10) + " + " + x.im.FloatString(10) + "i" + case string: + return x + } + panic("unreachable") +} + + +func (x Const) UnaryOp(op token.Token) Const { + panic("unimplemented") +} + + +func (x Const) BinaryOp(op token.Token, y Const) Const { + var z interface{} + switch x := x.val.(type) { + case bool: + z = binaryBoolOp(x, op, y.val.(bool)) + case *big.Int: + z = binaryIntOp(x, op, y.val.(*big.Int)) + case *big.Rat: + z = binaryFloatOp(x, op, y.val.(*big.Rat)) + case cmplx: + z = binaryCmplxOp(x, op, y.val.(cmplx)) + case string: + z = binaryStringOp(x, op, y.val.(string)) + default: + panic("unreachable") + } + return Const{z} +} + + +func binaryBoolOp(x bool, op token.Token, y bool) interface{} { + switch op { + case token.EQL: + return x == y + case token.NEQ: + return x != y + } + panic("unreachable") +} + + +func binaryIntOp(x *big.Int, op token.Token, y *big.Int) interface{} { + var z big.Int + switch op { + case token.ADD: + return z.Add(x, y) + case token.SUB: + return z.Sub(x, y) + case token.MUL: + return z.Mul(x, y) + case token.QUO: + return z.Quo(x, y) + case token.REM: + return z.Rem(x, y) + case token.AND: + return z.And(x, y) + case token.OR: + return z.Or(x, y) + case token.XOR: + return z.Xor(x, y) + case token.AND_NOT: + return z.AndNot(x, y) + case token.SHL: + panic("unimplemented") + case token.SHR: + panic("unimplemented") + case token.EQL: + return x.Cmp(y) == 0 + case token.NEQ: + return x.Cmp(y) != 0 + case token.LSS: + return x.Cmp(y) < 0 + case token.LEQ: + return x.Cmp(y) <= 0 + case token.GTR: + return x.Cmp(y) > 0 + case token.GEQ: + return x.Cmp(y) >= 0 + } + panic("unreachable") +} + + +func binaryFloatOp(x *big.Rat, op token.Token, y *big.Rat) interface{} { + var z big.Rat + switch op { + case token.ADD: + return z.Add(x, y) + case token.SUB: + return z.Sub(x, y) + case token.MUL: + return z.Mul(x, y) + case token.QUO: + return z.Quo(x, y) + case token.EQL: + return x.Cmp(y) == 0 + case token.NEQ: + return x.Cmp(y) != 0 + case token.LSS: + return x.Cmp(y) < 0 + case token.LEQ: + return x.Cmp(y) <= 0 + case token.GTR: + return x.Cmp(y) > 0 + case token.GEQ: + return x.Cmp(y) >= 0 + } + panic("unreachable") +} + + +func binaryCmplxOp(x cmplx, op token.Token, y cmplx) interface{} { + a, b := x.re, x.im + c, d := y.re, y.im + switch op { + case token.ADD: + // (a+c) + i(b+d) + var re, im big.Rat + re.Add(a, c) + im.Add(b, d) + return cmplx{&re, &im} + case token.SUB: + // (a-c) + i(b-d) + var re, im big.Rat + re.Sub(a, c) + im.Sub(b, d) + return cmplx{&re, &im} + case token.MUL: + // (ac-bd) + i(bc+ad) + var ac, bd, bc, ad big.Rat + ac.Mul(a, c) + bd.Mul(b, d) + bc.Mul(b, c) + ad.Mul(a, d) + var re, im big.Rat + re.Sub(&ac, &bd) + im.Add(&bc, &ad) + return cmplx{&re, &im} + case token.QUO: + // (ac+bd)/s + i(bc-ad)/s, with s = cc + dd + var ac, bd, bc, ad, s big.Rat + ac.Mul(a, c) + bd.Mul(b, d) + bc.Mul(b, c) + ad.Mul(a, d) + s.Add(c.Mul(c, c), d.Mul(d, d)) + var re, im big.Rat + re.Add(&ac, &bd) + re.Quo(&re, &s) + im.Sub(&bc, &ad) + im.Quo(&im, &s) + return cmplx{&re, &im} + case token.EQL: + return a.Cmp(c) == 0 && b.Cmp(d) == 0 + case token.NEQ: + return a.Cmp(c) != 0 || b.Cmp(d) != 0 + } + panic("unreachable") +} + + +func binaryStringOp(x string, op token.Token, y string) interface{} { + switch op { + case token.ADD: + return x + y + case token.EQL: + return x == y + case token.NEQ: + return x != y + case token.LSS: + return x < y + case token.LEQ: + return x <= y + case token.GTR: + return x > y + case token.GEQ: + return x >= y + } + panic("unreachable") +} diff --git a/libgo/go/go/types/exportdata.go b/libgo/go/go/types/exportdata.go new file mode 100644 index 00000000000..cb08ffe18a2 --- /dev/null +++ b/libgo/go/go/types/exportdata.go @@ -0,0 +1,135 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements ExportData. + +package types + +import ( + "bufio" + "fmt" + "io" + "os" + "strconv" + "strings" +) + + +func readGopackHeader(buf *bufio.Reader) (name string, size int, err os.Error) { + // See $GOROOT/include/ar.h. + hdr := make([]byte, 64+12+6+6+8+10+2) + _, err = io.ReadFull(buf, hdr) + if err != nil { + return + } + if trace { + fmt.Printf("header: %s", hdr) + } + s := strings.TrimSpace(string(hdr[64+12+6+6+8:][:10])) + size, err = strconv.Atoi(s) + if err != nil || hdr[len(hdr)-2] != '`' || hdr[len(hdr)-1] != '\n' { + err = os.ErrorString("invalid archive header") + return + } + name = strings.TrimSpace(string(hdr[:64])) + return +} + + +type dataReader struct { + *bufio.Reader + io.Closer +} + + +// ExportData returns a readCloser positioned at the beginning of the +// export data section of the given object/archive file, or an error. +// It is the caller's responsibility to close the readCloser. +// +func ExportData(filename string) (rc io.ReadCloser, err os.Error) { + file, err := os.Open(filename) + if err != nil { + return + } + + defer func() { + if err != nil { + file.Close() + // Add file name to error. + err = fmt.Errorf("reading export data: %s: %v", filename, err) + } + }() + + buf := bufio.NewReader(file) + + // Read first line to make sure this is an object file. + line, err := buf.ReadSlice('\n') + if err != nil { + return + } + if string(line) == "!\n" { + // Archive file. Scan to __.PKGDEF, which should + // be second archive entry. + var name string + var size int + + // First entry should be __.SYMDEF. + // Read and discard. + if name, size, err = readGopackHeader(buf); err != nil { + return + } + if name != "__.SYMDEF" { + err = os.ErrorString("go archive does not begin with __.SYMDEF") + return + } + const block = 4096 + tmp := make([]byte, block) + for size > 0 { + n := size + if n > block { + n = block + } + _, err = io.ReadFull(buf, tmp[:n]) + if err != nil { + return + } + size -= n + } + + // Second entry should be __.PKGDEF. + if name, size, err = readGopackHeader(buf); err != nil { + return + } + if name != "__.PKGDEF" { + err = os.ErrorString("go archive is missing __.PKGDEF") + return + } + + // Read first line of __.PKGDEF data, so that line + // is once again the first line of the input. + line, err = buf.ReadSlice('\n') + if err != nil { + return + } + } + + // Now at __.PKGDEF in archive or still at beginning of file. + // Either way, line should begin with "go object ". + if !strings.HasPrefix(string(line), "go object ") { + err = os.ErrorString("not a go object file") + return + } + + // Skip over object header to export data. + // Begins after first line with $$. + for line[0] != '$' { + line, err = buf.ReadSlice('\n') + if err != nil { + return + } + } + + rc = &dataReader{buf, file} + return +} diff --git a/libgo/go/go/types/gcimporter.go b/libgo/go/go/types/gcimporter.go new file mode 100644 index 00000000000..30adc04e729 --- /dev/null +++ b/libgo/go/go/types/gcimporter.go @@ -0,0 +1,792 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements an ast.Importer for gc generated object files. +// TODO(gri) Eventually move this into a separate package outside types. + +package types + +import ( + "big" + "fmt" + "go/ast" + "go/token" + "io" + "os" + "path/filepath" + "runtime" + "scanner" + "strconv" +) + + +const trace = false // set to true for debugging + +var ( + pkgRoot = filepath.Join(runtime.GOROOT(), "pkg", runtime.GOOS+"_"+runtime.GOARCH) + pkgExts = [...]string{".a", ".5", ".6", ".8"} +) + + +// findPkg returns the filename and package id for an import path. +// If no file was found, an empty filename is returned. +func findPkg(path string) (filename, id string) { + if len(path) == 0 { + return + } + + id = path + var noext string + switch path[0] { + default: + // "x" -> "$GOROOT/pkg/$GOOS_$GOARCH/x.ext", "x" + noext = filepath.Join(pkgRoot, path) + + case '.': + // "./x" -> "/this/directory/x.ext", "/this/directory/x" + cwd, err := os.Getwd() + if err != nil { + return + } + noext = filepath.Join(cwd, path) + id = noext + + case '/': + // "/x" -> "/x.ext", "/x" + noext = path + } + + // try extensions + for _, ext := range pkgExts { + filename = noext + ext + if f, err := os.Stat(filename); err == nil && f.IsRegular() { + return + } + } + + filename = "" // not found + return +} + + +// gcParser parses the exports inside a gc compiler-produced +// object/archive file and populates its scope with the results. +type gcParser struct { + scanner scanner.Scanner + tok int // current token + lit string // literal string; only valid for Ident, Int, String tokens + id string // package id of imported package + scope *ast.Scope // scope of imported package; alias for deps[id] + deps map[string]*ast.Scope // package id -> package scope +} + + +func (p *gcParser) init(filename, id string, src io.Reader) { + p.scanner.Init(src) + p.scanner.Error = func(_ *scanner.Scanner, msg string) { p.error(msg) } + p.scanner.Mode = scanner.ScanIdents | scanner.ScanInts | scanner.ScanStrings | scanner.ScanComments | scanner.SkipComments + p.scanner.Whitespace = 1<<'\t' | 1<<' ' + p.scanner.Filename = filename // for good error messages + p.next() + p.id = id + p.scope = ast.NewScope(nil) + p.deps = map[string]*ast.Scope{"unsafe": Unsafe, id: p.scope} +} + + +func (p *gcParser) next() { + p.tok = p.scanner.Scan() + switch p.tok { + case scanner.Ident, scanner.Int, scanner.String: + p.lit = p.scanner.TokenText() + default: + p.lit = "" + } + if trace { + fmt.Printf("%s: %q -> %q\n", scanner.TokenString(p.tok), p.scanner.TokenText(), p.lit) + } +} + + +// GcImporter implements the ast.Importer signature. +func GcImporter(path string) (name string, scope *ast.Scope, err os.Error) { + if path == "unsafe" { + return path, Unsafe, nil + } + + defer func() { + if r := recover(); r != nil { + err = r.(importError) // will re-panic if r is not an importError + if trace { + panic(err) // force a stack trace + } + } + }() + + filename, id := findPkg(path) + if filename == "" { + err = os.ErrorString("can't find import: " + id) + return + } + + buf, err := ExportData(filename) + if err != nil { + return + } + defer buf.Close() + + if trace { + fmt.Printf("importing %s\n", filename) + } + + var p gcParser + p.init(filename, id, buf) + name, scope = p.parseExport() + + return +} + + +// ---------------------------------------------------------------------------- +// Error handling + +// Internal errors are boxed as importErrors. +type importError struct { + pos scanner.Position + err os.Error +} + + +func (e importError) String() string { + return fmt.Sprintf("import error %s (byte offset = %d): %s", e.pos, e.pos.Offset, e.err) +} + + +func (p *gcParser) error(err interface{}) { + if s, ok := err.(string); ok { + err = os.ErrorString(s) + } + // panic with a runtime.Error if err is not an os.Error + panic(importError{p.scanner.Pos(), err.(os.Error)}) +} + + +func (p *gcParser) errorf(format string, args ...interface{}) { + p.error(fmt.Sprintf(format, args...)) +} + + +func (p *gcParser) expect(tok int) string { + lit := p.lit + if p.tok != tok { + p.errorf("expected %q, got %q (%q)", scanner.TokenString(tok), scanner.TokenString(p.tok), lit) + } + p.next() + return lit +} + + +func (p *gcParser) expectSpecial(tok string) { + sep := 'x' // not white space + i := 0 + for i < len(tok) && p.tok == int(tok[i]) && sep > ' ' { + sep = p.scanner.Peek() // if sep <= ' ', there is white space before the next token + p.next() + i++ + } + if i < len(tok) { + p.errorf("expected %q, got %q", tok, tok[0:i]) + } +} + + +func (p *gcParser) expectKeyword(keyword string) { + lit := p.expect(scanner.Ident) + if lit != keyword { + p.errorf("expected keyword %s, got %q", keyword, lit) + } +} + + +// ---------------------------------------------------------------------------- +// Import declarations + +// ImportPath = string_lit . +// +func (p *gcParser) parsePkgId() *ast.Scope { + id, err := strconv.Unquote(p.expect(scanner.String)) + if err != nil { + p.error(err) + } + + scope := p.scope // id == "" stands for the imported package id + if id != "" { + if scope = p.deps[id]; scope == nil { + scope = ast.NewScope(nil) + p.deps[id] = scope + } + } + + return scope +} + + +// dotIdentifier = ( ident | '·' ) { ident | int | '·' } . +func (p *gcParser) parseDotIdent() string { + ident := "" + if p.tok != scanner.Int { + sep := 'x' // not white space + for (p.tok == scanner.Ident || p.tok == scanner.Int || p.tok == '·') && sep > ' ' { + ident += p.lit + sep = p.scanner.Peek() // if sep <= ' ', there is white space before the next token + p.next() + } + } + if ident == "" { + p.expect(scanner.Ident) // use expect() for error handling + } + return ident +} + + +// ExportedName = ImportPath "." dotIdentifier . +// +func (p *gcParser) parseExportedName(kind ast.ObjKind) *ast.Object { + scope := p.parsePkgId() + p.expect('.') + name := p.parseDotIdent() + + // a type may have been declared before - if it exists + // already in the respective package scope, return that + // type + if kind == ast.Typ { + if obj := scope.Lookup(name); obj != nil { + assert(obj.Kind == ast.Typ) + return obj + } + } + + // any other object must be a newly declared object - + // create it and insert it into the package scope + obj := ast.NewObj(kind, name) + if scope.Insert(obj) != nil { + p.errorf("already declared: %s", obj.Name) + } + + // a new type object is a named type and may be referred + // to before the underlying type is known - set it up + if kind == ast.Typ { + obj.Type = &Name{Obj: obj} + } + + return obj +} + + +// ---------------------------------------------------------------------------- +// Types + +// BasicType = identifier . +// +func (p *gcParser) parseBasicType() Type { + obj := Universe.Lookup(p.expect(scanner.Ident)) + if obj == nil || obj.Kind != ast.Typ { + p.errorf("not a basic type: %s", obj.Name) + } + return obj.Type.(Type) +} + + +// ArrayType = "[" int_lit "]" Type . +// +func (p *gcParser) parseArrayType() Type { + // "[" already consumed and lookahead known not to be "]" + lit := p.expect(scanner.Int) + p.expect(']') + elt := p.parseType() + n, err := strconv.Atoui64(lit) + if err != nil { + p.error(err) + } + return &Array{Len: n, Elt: elt} +} + + +// MapType = "map" "[" Type "]" Type . +// +func (p *gcParser) parseMapType() Type { + p.expectKeyword("map") + p.expect('[') + key := p.parseType() + p.expect(']') + elt := p.parseType() + return &Map{Key: key, Elt: elt} +} + + +// Name = identifier | "?" . +// +func (p *gcParser) parseName() (name string) { + switch p.tok { + case scanner.Ident: + name = p.lit + p.next() + case '?': + // anonymous + p.next() + default: + p.error("name expected") + } + return +} + + +// Field = Name Type [ ":" string_lit ] . +// +func (p *gcParser) parseField(scope *ast.Scope) { + // TODO(gri) The code below is not correct for anonymous fields: + // The name is the type name; it should not be empty. + name := p.parseName() + ftyp := p.parseType() + if name == "" { + // anonymous field - ftyp must be T or *T and T must be a type name + ftyp = Deref(ftyp) + if ftyp, ok := ftyp.(*Name); ok { + name = ftyp.Obj.Name + } else { + p.errorf("anonymous field expected") + } + } + if p.tok == ':' { + p.next() + tag := p.expect(scanner.String) + _ = tag // TODO(gri) store tag somewhere + } + fld := ast.NewObj(ast.Var, name) + fld.Type = ftyp + scope.Insert(fld) +} + + +// StructType = "struct" "{" [ FieldList ] "}" . +// FieldList = Field { ";" Field } . +// +func (p *gcParser) parseStructType() Type { + p.expectKeyword("struct") + p.expect('{') + scope := ast.NewScope(nil) + if p.tok != '}' { + p.parseField(scope) + for p.tok == ';' { + p.next() + p.parseField(scope) + } + } + p.expect('}') + return &Struct{} +} + + +// Parameter = ( identifier | "?" ) [ "..." ] Type . +// +func (p *gcParser) parseParameter(scope *ast.Scope, isVariadic *bool) { + name := p.parseName() + if name == "" { + name = "_" // cannot access unnamed identifiers + } + if isVariadic != nil { + if *isVariadic { + p.error("... not on final argument") + } + if p.tok == '.' { + p.expectSpecial("...") + *isVariadic = true + } + } + ptyp := p.parseType() + par := ast.NewObj(ast.Var, name) + par.Type = ptyp + scope.Insert(par) +} + + +// Parameters = "(" [ ParameterList ] ")" . +// ParameterList = { Parameter "," } Parameter . +// +func (p *gcParser) parseParameters(scope *ast.Scope, isVariadic *bool) { + p.expect('(') + if p.tok != ')' { + p.parseParameter(scope, isVariadic) + for p.tok == ',' { + p.next() + p.parseParameter(scope, isVariadic) + } + } + p.expect(')') +} + + +// Signature = Parameters [ Result ] . +// Result = Type | Parameters . +// +func (p *gcParser) parseSignature(scope *ast.Scope, isVariadic *bool) { + p.parseParameters(scope, isVariadic) + + // optional result type + switch p.tok { + case scanner.Ident, scanner.String, '[', '*', '<': + // single, unnamed result + result := ast.NewObj(ast.Var, "_") + result.Type = p.parseType() + scope.Insert(result) + case '(': + // named or multiple result(s) + p.parseParameters(scope, nil) + } +} + + +// FuncType = "func" Signature . +// +func (p *gcParser) parseFuncType() Type { + // "func" already consumed + scope := ast.NewScope(nil) + isVariadic := false + p.parseSignature(scope, &isVariadic) + return &Func{IsVariadic: isVariadic} +} + + +// MethodSpec = identifier Signature . +// +func (p *gcParser) parseMethodSpec(scope *ast.Scope) { + if p.tok == scanner.Ident { + p.expect(scanner.Ident) + } else { + p.parsePkgId() + p.expect('.') + p.parseDotIdent() + } + isVariadic := false + p.parseSignature(scope, &isVariadic) +} + + +// InterfaceType = "interface" "{" [ MethodList ] "}" . +// MethodList = MethodSpec { ";" MethodSpec } . +// +func (p *gcParser) parseInterfaceType() Type { + p.expectKeyword("interface") + p.expect('{') + scope := ast.NewScope(nil) + if p.tok != '}' { + p.parseMethodSpec(scope) + for p.tok == ';' { + p.next() + p.parseMethodSpec(scope) + } + } + p.expect('}') + return &Interface{} +} + + +// ChanType = ( "chan" [ "<-" ] | "<-" "chan" ) Type . +// +func (p *gcParser) parseChanType() Type { + dir := ast.SEND | ast.RECV + if p.tok == scanner.Ident { + p.expectKeyword("chan") + if p.tok == '<' { + p.expectSpecial("<-") + dir = ast.SEND + } + } else { + p.expectSpecial("<-") + p.expectKeyword("chan") + dir = ast.RECV + } + elt := p.parseType() + return &Chan{Dir: dir, Elt: elt} +} + + +// Type = +// BasicType | TypeName | ArrayType | SliceType | StructType | +// PointerType | FuncType | InterfaceType | MapType | ChanType | +// "(" Type ")" . +// BasicType = ident . +// TypeName = ExportedName . +// SliceType = "[" "]" Type . +// PointerType = "*" Type . +// +func (p *gcParser) parseType() Type { + switch p.tok { + case scanner.Ident: + switch p.lit { + default: + return p.parseBasicType() + case "struct": + return p.parseStructType() + case "func": + p.next() // parseFuncType assumes "func" is already consumed + return p.parseFuncType() + case "interface": + return p.parseInterfaceType() + case "map": + return p.parseMapType() + case "chan": + return p.parseChanType() + } + case scanner.String: + // TypeName + return p.parseExportedName(ast.Typ).Type.(Type) + case '[': + p.next() // look ahead + if p.tok == ']' { + // SliceType + p.next() + return &Slice{Elt: p.parseType()} + } + return p.parseArrayType() + case '*': + // PointerType + p.next() + return &Pointer{Base: p.parseType()} + case '<': + return p.parseChanType() + case '(': + // "(" Type ")" + p.next() + typ := p.parseType() + p.expect(')') + return typ + } + p.errorf("expected type, got %s (%q)", scanner.TokenString(p.tok), p.lit) + return nil +} + + +// ---------------------------------------------------------------------------- +// Declarations + +// ImportDecl = "import" identifier string_lit . +// +func (p *gcParser) parseImportDecl() { + p.expectKeyword("import") + // The identifier has no semantic meaning in the import data. + // It exists so that error messages can print the real package + // name: binary.ByteOrder instead of "encoding/binary".ByteOrder. + // TODO(gri): Save package id -> package name mapping. + p.expect(scanner.Ident) + p.parsePkgId() +} + + +// int_lit = [ "+" | "-" ] { "0" ... "9" } . +// +func (p *gcParser) parseInt() (sign, val string) { + switch p.tok { + case '-': + p.next() + sign = "-" + case '+': + p.next() + } + val = p.expect(scanner.Int) + return +} + + +// number = int_lit [ "p" int_lit ] . +// +func (p *gcParser) parseNumber() Const { + // mantissa + sign, val := p.parseInt() + mant, ok := new(big.Int).SetString(sign+val, 10) + assert(ok) + + if p.lit == "p" { + // exponent (base 2) + p.next() + sign, val = p.parseInt() + exp, err := strconv.Atoui(val) + if err != nil { + p.error(err) + } + if sign == "-" { + denom := big.NewInt(1) + denom.Lsh(denom, exp) + return Const{new(big.Rat).SetFrac(mant, denom)} + } + if exp > 0 { + mant.Lsh(mant, exp) + } + return Const{new(big.Rat).SetInt(mant)} + } + + return Const{mant} +} + + +// ConstDecl = "const" ExportedName [ Type ] "=" Literal . +// Literal = bool_lit | int_lit | float_lit | complex_lit | string_lit . +// bool_lit = "true" | "false" . +// complex_lit = "(" float_lit "+" float_lit ")" . +// string_lit = `"` { unicode_char } `"` . +// +func (p *gcParser) parseConstDecl() { + p.expectKeyword("const") + obj := p.parseExportedName(ast.Con) + var x Const + var typ Type + if p.tok != '=' { + obj.Type = p.parseType() + } + p.expect('=') + switch p.tok { + case scanner.Ident: + // bool_lit + if p.lit != "true" && p.lit != "false" { + p.error("expected true or false") + } + x = Const{p.lit == "true"} + typ = Bool.Underlying + p.next() + case '-', scanner.Int: + // int_lit + x = p.parseNumber() + typ = Int.Underlying + if _, ok := x.val.(*big.Rat); ok { + typ = Float64.Underlying + } + case '(': + // complex_lit + p.next() + re := p.parseNumber() + p.expect('+') + im := p.parseNumber() + p.expect(')') + x = Const{cmplx{re.val.(*big.Rat), im.val.(*big.Rat)}} + typ = Complex128.Underlying + case scanner.String: + // string_lit + x = MakeConst(token.STRING, p.lit) + p.next() + typ = String.Underlying + default: + p.error("expected literal") + } + if obj.Type == nil { + obj.Type = typ + } + _ = x // TODO(gri) store x somewhere +} + + +// TypeDecl = "type" ExportedName Type . +// +func (p *gcParser) parseTypeDecl() { + p.expectKeyword("type") + obj := p.parseExportedName(ast.Typ) + typ := p.parseType() + + name := obj.Type.(*Name) + assert(name.Underlying == nil) + assert(Underlying(typ) == typ) + name.Underlying = typ +} + + +// VarDecl = "var" ExportedName Type . +// +func (p *gcParser) parseVarDecl() { + p.expectKeyword("var") + obj := p.parseExportedName(ast.Var) + obj.Type = p.parseType() +} + + +// FuncDecl = "func" ExportedName Signature . +// +func (p *gcParser) parseFuncDecl() { + // "func" already consumed + obj := p.parseExportedName(ast.Fun) + obj.Type = p.parseFuncType() +} + + +// MethodDecl = "func" Receiver identifier Signature . +// Receiver = "(" ( identifier | "?" ) [ "*" ] ExportedName ")" . +// +func (p *gcParser) parseMethodDecl() { + // "func" already consumed + scope := ast.NewScope(nil) // method scope + p.expect('(') + p.parseParameter(scope, nil) // receiver + p.expect(')') + p.expect(scanner.Ident) + isVariadic := false + p.parseSignature(scope, &isVariadic) + +} + + +// Decl = [ ImportDecl | ConstDecl | TypeDecl | VarDecl | FuncDecl | MethodDecl ] "\n" . +// +func (p *gcParser) parseDecl() { + switch p.lit { + case "import": + p.parseImportDecl() + case "const": + p.parseConstDecl() + case "type": + p.parseTypeDecl() + case "var": + p.parseVarDecl() + case "func": + p.next() // look ahead + if p.tok == '(' { + p.parseMethodDecl() + } else { + p.parseFuncDecl() + } + } + p.expect('\n') +} + + +// ---------------------------------------------------------------------------- +// Export + +// Export = "PackageClause { Decl } "$$" . +// PackageClause = "package" identifier [ "safe" ] "\n" . +// +func (p *gcParser) parseExport() (string, *ast.Scope) { + p.expectKeyword("package") + name := p.expect(scanner.Ident) + if p.tok != '\n' { + // A package is safe if it was compiled with the -u flag, + // which disables the unsafe package. + // TODO(gri) remember "safe" package + p.expectKeyword("safe") + } + p.expect('\n') + + for p.tok != '$' && p.tok != scanner.EOF { + p.parseDecl() + } + + if ch := p.scanner.Peek(); p.tok != '$' || ch != '$' { + // don't call next()/expect() since reading past the + // export data may cause scanner errors (e.g. NUL chars) + p.errorf("expected '$$', got %s %c", scanner.TokenString(p.tok), ch) + } + + if n := p.scanner.ErrorCount; n != 0 { + p.errorf("expected no scanner errors, got %d", n) + } + + return name, p.scope +} diff --git a/libgo/go/go/types/gcimporter_test.go b/libgo/go/go/types/gcimporter_test.go new file mode 100644 index 00000000000..556e761df2d --- /dev/null +++ b/libgo/go/go/types/gcimporter_test.go @@ -0,0 +1,111 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package types + +import ( + "exec" + "io/ioutil" + "path/filepath" + "runtime" + "strings" + "testing" + "time" +) + + +var gcName, gcPath string // compiler name and path + +func init() { + // determine compiler + switch runtime.GOARCH { + case "386": + gcName = "8g" + case "amd64": + gcName = "6g" + case "arm": + gcName = "5g" + default: + gcName = "unknown-GOARCH-compiler" + gcPath = gcName + return + } + gcPath, _ = exec.LookPath(gcName) +} + + +func compile(t *testing.T, dirname, filename string) { + cmd, err := exec.Run(gcPath, []string{gcPath, filename}, nil, dirname, exec.DevNull, exec.Pipe, exec.MergeWithStdout) + if err != nil { + t.Errorf("%s %s failed: %s", gcName, filename, err) + return + } + defer cmd.Close() + + msg, err := cmd.Wait(0) + if err != nil { + t.Errorf("%s %s failed: %s", gcName, filename, err) + return + } + + if !msg.Exited() || msg.ExitStatus() != 0 { + t.Errorf("%s %s failed: exit status = %d", gcName, filename, msg.ExitStatus()) + output, _ := ioutil.ReadAll(cmd.Stdout) + t.Log(string(output)) + } +} + + +func testPath(t *testing.T, path string) bool { + _, _, err := GcImporter(path) + if err != nil { + t.Errorf("testPath(%s): %s", path, err) + return false + } + return true +} + + +const maxTime = 3e9 // maximum allotted testing time in ns + +func testDir(t *testing.T, dir string, endTime int64) (nimports int) { + dirname := filepath.Join(pkgRoot, dir) + list, err := ioutil.ReadDir(dirname) + if err != nil { + t.Errorf("testDir(%s): %s", dirname, err) + } + for _, f := range list { + if time.Nanoseconds() >= endTime { + t.Log("testing time used up") + return + } + switch { + case f.IsRegular(): + // try extensions + for _, ext := range pkgExts { + if strings.HasSuffix(f.Name, ext) { + name := f.Name[0 : len(f.Name)-len(ext)] // remove extension + if testPath(t, filepath.Join(dir, name)) { + nimports++ + } + } + } + case f.IsDirectory(): + nimports += testDir(t, filepath.Join(dir, f.Name), endTime) + } + } + return +} + + +func TestGcImport(t *testing.T) { + compile(t, "testdata", "exports.go") + + nimports := 0 + if testPath(t, "./testdata/exports") { + nimports++ + } + nimports += testDir(t, "", time.Nanoseconds()+maxTime) // installed packages + t.Logf("tested %d imports", nimports) +} diff --git a/libgo/go/go/types/testdata/exports.go b/libgo/go/go/types/testdata/exports.go new file mode 100644 index 00000000000..13efe012a0b --- /dev/null +++ b/libgo/go/go/types/testdata/exports.go @@ -0,0 +1,89 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file is used to generate a .6 object file which +// serves as test file for gcimporter_test.go. + +package exports + +import ( + "go/ast" +) + + +const ( + C0 int = 0 + C1 = 3.14159265 + C2 = 2.718281828i + C3 = -123.456e-789 + C4 = +123.456E+789 + C5 = 1234i + C6 = "foo\n" + C7 = `bar\n` +) + + +type ( + T1 int + T2 [10]int + T3 []int + T4 *int + T5 chan int + T6a chan<- int + T6b chan (<-chan int) + T6c chan<- (chan int) + T7 <-chan *ast.File + T8 struct{} + T9 struct { + a int + b, c float32 + d []string "tag" + } + T10 struct { + T8 + T9 + _ *T10 + } + T11 map[int]string + T12 interface{} + T13 interface { + m1() + m2(int) float32 + } + T14 interface { + T12 + T13 + m3(x ...struct{}) []T9 + } + T15 func() + T16 func(int) + T17 func(x int) + T18 func() float32 + T19 func() (x float32) + T20 func(...interface{}) + T21 struct{ next *T21 } + T22 struct{ link *T23 } + T23 struct{ link *T22 } + T24 *T24 + T25 *T26 + T26 *T27 + T27 *T25 + T28 func(T28) T28 +) + + +var ( + V0 int + V1 = -991.0 +) + + +func F1() {} +func F2(x int) {} +func F3() int { return 0 } +func F4() float32 { return 0 } +func F5(a, b, c int, u, v, w struct{ x, y T1 }, more ...interface{}) (p, q, r chan<- T10) + + +func (p *T1) M1() diff --git a/libgo/go/go/types/types.go b/libgo/go/go/types/types.go new file mode 100644 index 00000000000..2ee645d989b --- /dev/null +++ b/libgo/go/go/types/types.go @@ -0,0 +1,122 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// PACKAGE UNDER CONSTRUCTION. ANY AND ALL PARTS MAY CHANGE. +// Package types declares the types used to represent Go types. +// +package types + +import "go/ast" + + +// All types implement the Type interface. +type Type interface { + isType() +} + + +// All concrete types embed ImplementsType which +// ensures that all types implement the Type interface. +type ImplementsType struct{} + +func (t *ImplementsType) isType() {} + + +// A Basic represents a (unnamed) basic type. +type Basic struct { + ImplementsType + // TODO(gri) need a field specifying the exact basic type +} + + +// An Array represents an array type [Len]Elt. +type Array struct { + ImplementsType + Len uint64 + Elt Type +} + + +// A Slice represents a slice type []Elt. +type Slice struct { + ImplementsType + Elt Type +} + + +// A Struct represents a struct type struct{...}. +type Struct struct { + ImplementsType + // TODO(gri) need to remember fields. +} + + +// A Pointer represents a pointer type *Base. +type Pointer struct { + ImplementsType + Base Type +} + + +// A Func represents a function type func(...) (...). +type Func struct { + ImplementsType + IsVariadic bool + // TODO(gri) need to remember parameters. +} + + +// An Interface represents an interface type interface{...}. +type Interface struct { + ImplementsType + // TODO(gri) need to remember methods. +} + + +// A Map represents a map type map[Key]Elt. +type Map struct { + ImplementsType + Key, Elt Type +} + + +// A Chan represents a channel type chan Elt, <-chan Elt, or chan<-Elt. +type Chan struct { + ImplementsType + Dir ast.ChanDir + Elt Type +} + + +// A Name represents a named type as declared in a type declaration. +type Name struct { + ImplementsType + Underlying Type // nil if not fully declared + Obj *ast.Object // corresponding declared object + // TODO(gri) need to remember fields and methods. +} + + +// If typ is a pointer type, Deref returns the pointer's base type; +// otherwise it returns typ. +func Deref(typ Type) Type { + if typ, ok := typ.(*Pointer); ok { + return typ.Base + } + return typ +} + + +// Underlying returns the underlying type of a type. +func Underlying(typ Type) Type { + if typ, ok := typ.(*Name); ok { + utyp := typ.Underlying + if _, ok := utyp.(*Basic); ok { + return typ + } + return utyp + + } + return typ +} diff --git a/libgo/go/go/types/universe.go b/libgo/go/go/types/universe.go new file mode 100644 index 00000000000..2a54a8ac12c --- /dev/null +++ b/libgo/go/go/types/universe.go @@ -0,0 +1,113 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// FILE UNDER CONSTRUCTION. ANY AND ALL PARTS MAY CHANGE. +// This file implements the universe and unsafe package scopes. + +package types + +import "go/ast" + + +var ( + scope, // current scope to use for initialization + Universe, + Unsafe *ast.Scope +) + + +func define(kind ast.ObjKind, name string) *ast.Object { + obj := ast.NewObj(kind, name) + if scope.Insert(obj) != nil { + panic("types internal error: double declaration") + } + return obj +} + + +func defType(name string) *Name { + obj := define(ast.Typ, name) + typ := &Name{Underlying: &Basic{}, Obj: obj} + obj.Type = typ + return typ +} + + +func defConst(name string) { + obj := define(ast.Con, name) + _ = obj // TODO(gri) fill in other properties +} + + +func defFun(name string) { + obj := define(ast.Fun, name) + _ = obj // TODO(gri) fill in other properties +} + + +var ( + Bool, + Int, + Float64, + Complex128, + String *Name +) + + +func init() { + Universe = ast.NewScope(nil) + scope = Universe + + Bool = defType("bool") + defType("byte") // TODO(gri) should be an alias for uint8 + defType("complex64") + Complex128 = defType("complex128") + defType("float32") + Float64 = defType("float64") + defType("int8") + defType("int16") + defType("int32") + defType("int64") + String = defType("string") + defType("uint8") + defType("uint16") + defType("uint32") + defType("uint64") + Int = defType("int") + defType("uint") + defType("uintptr") + + defConst("true") + defConst("false") + defConst("iota") + defConst("nil") + + defFun("append") + defFun("cap") + defFun("close") + defFun("complex") + defFun("copy") + defFun("imag") + defFun("len") + defFun("make") + defFun("new") + defFun("panic") + defFun("print") + defFun("println") + defFun("real") + defFun("recover") + + Unsafe = ast.NewScope(nil) + scope = Unsafe + defType("Pointer") + + defFun("Alignof") + defFun("New") + defFun("NewArray") + defFun("Offsetof") + defFun("Reflect") + defFun("Sizeof") + defFun("Typeof") + defFun("Unreflect") +} diff --git a/libgo/go/gob/codec_test.go b/libgo/go/gob/codec_test.go index 28042ccaa3a..8961336cd34 100644 --- a/libgo/go/gob/codec_test.go +++ b/libgo/go/gob/codec_test.go @@ -999,13 +999,12 @@ type Bad0 struct { C float64 } - func TestInvalidField(t *testing.T) { var bad0 Bad0 bad0.CH = make(chan int) b := new(bytes.Buffer) dummyEncoder := new(Encoder) // sufficient for this purpose. - dummyEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0))) + dummyEncoder.encode(b, reflect.ValueOf(&bad0), userType(reflect.TypeOf(&bad0))) if err := dummyEncoder.err; err == nil { t.Error("expected error; got none") } else if strings.Index(err.String(), "type") < 0 { diff --git a/libgo/go/gob/decode.go b/libgo/go/gob/decode.go index f8159d4ea32..0e86df6b57a 100644 --- a/libgo/go/gob/decode.go +++ b/libgo/go/gob/decode.go @@ -406,7 +406,7 @@ func decUint8Array(i *decInstr, state *decoderState, p unsafe.Pointer) { func decString(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { - *(*unsafe.Pointer)(p) = unsafe.Pointer(new([]byte)) + *(*unsafe.Pointer)(p) = unsafe.Pointer(new(string)) } p = *(*unsafe.Pointer)(p) } @@ -468,7 +468,7 @@ func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) basep := p delta := int(state.decodeUint()) if delta != 0 { - errorf("gob decode: corrupted data: non-zero delta for singleton") + errorf("decode: corrupted data: non-zero delta for singleton") } instr := &engine.instr[singletonField] ptr := unsafe.Pointer(basep) // offset will be zero @@ -486,14 +486,14 @@ func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) // This state cannot arise for decodeSingle, which is called directly // from the user's value, not from the innards of an engine. func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, indir int) { - p = allocate(ut.base.(*reflect.StructType), p, indir) + p = allocate(ut.base, p, indir) state := dec.newDecoderState(&dec.buf) state.fieldnum = -1 basep := p for state.b.Len() > 0 { delta := int(state.decodeUint()) if delta < 0 { - errorf("gob decode: corrupted data: negative delta") + errorf("decode: corrupted data: negative delta") } if delta == 0 { // struct terminator is zero delta fieldnum break @@ -521,7 +521,7 @@ func (dec *Decoder) ignoreStruct(engine *decEngine) { for state.b.Len() > 0 { delta := int(state.decodeUint()) if delta < 0 { - errorf("gob ignore decode: corrupted data: negative delta") + errorf("ignore decode: corrupted data: negative delta") } if delta == 0 { // struct terminator is zero delta fieldnum break @@ -544,7 +544,7 @@ func (dec *Decoder) ignoreSingle(engine *decEngine) { state.fieldnum = singletonField delta := int(state.decodeUint()) if delta != 0 { - errorf("gob decode: corrupted data: non-zero delta for singleton") + errorf("decode: corrupted data: non-zero delta for singleton") } instr := &engine.instr[singletonField] instr.op(instr, state, unsafe.Pointer(nil)) @@ -567,12 +567,12 @@ func (dec *Decoder) decodeArrayHelper(state *decoderState, p uintptr, elemOp dec // decodeArray decodes an array and stores it through p, that is, p points to the zeroth element. // The length is an unsigned integer preceding the elements. Even though the length is redundant // (it's part of the type), it's a useful check and is included in the encoding. -func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) { +func (dec *Decoder) decodeArray(atyp reflect.Type, state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) { if indir > 0 { p = allocate(atyp, p, 1) // All but the last level has been allocated by dec.Indirect } if n := state.decodeUint(); n != uint64(length) { - errorf("gob: length mismatch in decodeArray") + errorf("length mismatch in decodeArray") } dec.decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl) } @@ -581,7 +581,7 @@ func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decoderState, p // unlike the other items we can't use a pointer directly. func decodeIntoValue(state *decoderState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value { instr := &decInstr{op, 0, indir, 0, ovfl} - up := unsafe.Pointer(v.UnsafeAddr()) + up := unsafe.Pointer(unsafeAddr(v)) if indir > 1 { up = decIndirect(up, indir) } @@ -593,24 +593,24 @@ func decodeIntoValue(state *decoderState, op decOp, indir int, v reflect.Value, // Maps are encoded as a length followed by key:value pairs. // Because the internals of maps are not visible to us, we must // use reflection rather than pointer magic. -func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decoderState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) { +func (dec *Decoder) decodeMap(mtyp reflect.Type, state *decoderState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) { if indir > 0 { p = allocate(mtyp, p, 1) // All but the last level has been allocated by dec.Indirect } up := unsafe.Pointer(p) if *(*unsafe.Pointer)(up) == nil { // maps are represented as a pointer in the runtime // Allocate map. - *(*unsafe.Pointer)(up) = unsafe.Pointer(reflect.MakeMap(mtyp).Get()) + *(*unsafe.Pointer)(up) = unsafe.Pointer(reflect.MakeMap(mtyp).Pointer()) } // Maps cannot be accessed by moving addresses around the way // that slices etc. can. We must recover a full reflection value for // the iteration. - v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer(p))).(*reflect.MapValue) + v := reflect.ValueOf(unsafe.Unreflect(mtyp, unsafe.Pointer(p))) n := int(state.decodeUint()) for i := 0; i < n; i++ { - key := decodeIntoValue(state, keyOp, keyIndir, reflect.MakeZero(mtyp.Key()), ovfl) - elem := decodeIntoValue(state, elemOp, elemIndir, reflect.MakeZero(mtyp.Elem()), ovfl) - v.SetElem(key, elem) + key := decodeIntoValue(state, keyOp, keyIndir, allocValue(mtyp.Key()), ovfl) + elem := decodeIntoValue(state, elemOp, elemIndir, allocValue(mtyp.Elem()), ovfl) + v.SetMapIndex(key, elem) } } @@ -625,7 +625,7 @@ func (dec *Decoder) ignoreArrayHelper(state *decoderState, elemOp decOp, length // ignoreArray discards the data for an array value with no destination. func (dec *Decoder) ignoreArray(state *decoderState, elemOp decOp, length int) { if n := state.decodeUint(); n != uint64(length) { - errorf("gob: length mismatch in ignoreArray") + errorf("length mismatch in ignoreArray") } dec.ignoreArrayHelper(state, elemOp, length) } @@ -643,7 +643,7 @@ func (dec *Decoder) ignoreMap(state *decoderState, keyOp, elemOp decOp) { // decodeSlice decodes a slice and stores the slice header through p. // Slices are encoded as an unsigned length followed by the elements. -func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) { +func (dec *Decoder) decodeSlice(atyp reflect.Type, state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) { n := int(uintptr(state.decodeUint())) if indir > 0 { up := unsafe.Pointer(p) @@ -667,27 +667,21 @@ func (dec *Decoder) ignoreSlice(state *decoderState, elemOp decOp) { dec.ignoreArrayHelper(state, elemOp, int(state.decodeUint())) } -// setInterfaceValue sets an interface value to a concrete value through -// reflection. If the concrete value does not implement the interface, the -// setting will panic. This routine turns the panic into an error return. -// This dance avoids manually checking that the value satisfies the -// interface. -// TODO(rsc): avoid panic+recover after fixing issue 327. -func setInterfaceValue(ivalue *reflect.InterfaceValue, value reflect.Value) { - defer func() { - if e := recover(); e != nil { - error(e.(os.Error)) - } - }() +// setInterfaceValue sets an interface value to a concrete value, +// but first it checks that the assignment will succeed. +func setInterfaceValue(ivalue reflect.Value, value reflect.Value) { + if !value.Type().AssignableTo(ivalue.Type()) { + errorf("cannot assign value of type %s to %s", value.Type(), ivalue.Type()) + } ivalue.Set(value) } // decodeInterface decodes an interface value and stores it through p. // Interfaces are encoded as the name of a concrete type followed by a value. // If the name is empty, the value is nil and no value is sent. -func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decoderState, p uintptr, indir int) { - // Create an interface reflect.Value. We need one even for the nil case. - ivalue := reflect.MakeZero(ityp).(*reflect.InterfaceValue) +func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p uintptr, indir int) { + // Create a writable interface reflect.Value. We need one even for the nil case. + ivalue := allocValue(ityp) // Read the name of the concrete type. b := make([]byte, state.decodeUint()) state.b.Read(b) @@ -695,13 +689,13 @@ func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decoderS if name == "" { // Copy the representation of the nil interface value to the target. // This is horribly unsafe and special. - *(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.Get() + *(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.InterfaceData() return } // The concrete type must be registered. typ, ok := nameToConcreteType[name] if !ok { - errorf("gob: name not registered for interface: %q", name) + errorf("name not registered for interface: %q", name) } // Read the type id of the concrete value. concreteId := dec.decodeTypeSequence(true) @@ -712,7 +706,7 @@ func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decoderS // in case we want to ignore the value by skipping it completely). state.decodeUint() // Read the concrete value. - value := reflect.MakeZero(typ) + value := allocValue(typ) dec.decodeValue(concreteId, value) if dec.err != nil { error(dec.err) @@ -726,7 +720,7 @@ func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decoderS setInterfaceValue(ivalue, value) // Copy the representation of the interface value to the target. // This is horribly unsafe and special. - *(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.Get() + *(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.InterfaceData() } // ignoreInterface discards the data for an interface value with no destination. @@ -823,8 +817,8 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg if op == nil { inProgress[rt] = &op // Special cases - switch t := typ.(type) { - case *reflect.ArrayType: + switch t := typ; t.Kind() { + case reflect.Array: name = "element of " + name elemId := dec.wireType[wireId].ArrayT.Elem elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress) @@ -833,7 +827,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg state.dec.decodeArray(t, state, uintptr(p), *elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) } - case *reflect.MapType: + case reflect.Map: name = "element of " + name keyId := dec.wireType[wireId].MapT.Key elemId := dec.wireType[wireId].MapT.Elem @@ -845,7 +839,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg state.dec.decodeMap(t, state, uintptr(up), *keyOp, *elemOp, i.indir, keyIndir, elemIndir, ovfl) } - case *reflect.SliceType: + case reflect.Slice: name = "element of " + name if t.Elem().Kind() == reflect.Uint8 { op = decUint8Array @@ -863,7 +857,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg state.dec.decodeSlice(t, state, uintptr(p), *elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) } - case *reflect.StructType: + case reflect.Struct: // Generate a closure that calls out to the engine for the nested type. enginePtr, err := dec.getDecEnginePtr(wireId, userType(typ)) if err != nil { @@ -873,14 +867,14 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg // indirect through enginePtr to delay evaluation for recursive structs. dec.decodeStruct(*enginePtr, userType(typ), uintptr(p), i.indir) } - case *reflect.InterfaceType: + case reflect.Interface: op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { state.dec.decodeInterface(t, state, uintptr(p), i.indir) } } } if op == nil { - errorf("gob: decode can't handle type %s", rt.String()) + errorf("decode can't handle type %s", rt.String()) } return &op, indir } @@ -901,7 +895,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { wire := dec.wireType[wireId] switch { case wire == nil: - errorf("gob: bad data: undefined type %s", wireId.string()) + errorf("bad data: undefined type %s", wireId.string()) case wire.ArrayT != nil: elemId := wire.ArrayT.Elem elemOp := dec.decIgnoreOpFor(elemId) @@ -943,7 +937,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { } } if op == nil { - errorf("gob: bad data: ignore can't handle type %s", wireId.string()) + errorf("bad data: ignore can't handle type %s", wireId.string()) } return op } @@ -951,32 +945,33 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { // gobDecodeOpFor returns the op for a type that is known to implement // GobDecoder. func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) (*decOp, int) { - rt := ut.user + rcvrType := ut.user if ut.decIndir == -1 { - rt = reflect.PtrTo(rt) + rcvrType = reflect.PtrTo(rcvrType) } else if ut.decIndir > 0 { for i := int8(0); i < ut.decIndir; i++ { - rt = rt.(*reflect.PtrType).Elem() + rcvrType = rcvrType.Elem() } } var op decOp op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { - // Allocate the underlying data, but hold on to the address we have, - // since we need it to get to the receiver's address. - allocate(ut.base, uintptr(p), ut.indir) + // Caller has gotten us to within one indirection of our value. + if i.indir > 0 { + if *(*unsafe.Pointer)(p) == nil { + *(*unsafe.Pointer)(p) = unsafe.New(ut.base) + } + } + // Now p is a pointer to the base type. Do we need to climb out to + // get to the receiver type? var v reflect.Value if ut.decIndir == -1 { - // Need to climb up one level to turn value into pointer. - v = reflect.NewValue(unsafe.Unreflect(rt, unsafe.Pointer(&p))) + v = reflect.ValueOf(unsafe.Unreflect(rcvrType, unsafe.Pointer(&p))) } else { - if ut.decIndir > 0 { - p = decIndirect(p, int(ut.decIndir)) - } - v = reflect.NewValue(unsafe.Unreflect(rt, p)) + v = reflect.ValueOf(unsafe.Unreflect(rcvrType, p)) } - state.dec.decodeGobDecoder(state, v, methodIndex(rt, gobDecodeMethodName)) + state.dec.decodeGobDecoder(state, v, methodIndex(rcvrType, gobDecodeMethodName)) } - return &op, int(ut.decIndir) + return &op, int(ut.indir) } @@ -999,37 +994,37 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[re if ut.isGobDecoder { // This test trumps all others. return true } - switch t := ut.base.(type) { + switch t := ut.base; t.Kind() { default: // chan, etc: cannot handle. return false - case *reflect.BoolType: + case reflect.Bool: return fw == tBool - case *reflect.IntType: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return fw == tInt - case *reflect.UintType: + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: return fw == tUint - case *reflect.FloatType: + case reflect.Float32, reflect.Float64: return fw == tFloat - case *reflect.ComplexType: + case reflect.Complex64, reflect.Complex128: return fw == tComplex - case *reflect.StringType: + case reflect.String: return fw == tString - case *reflect.InterfaceType: + case reflect.Interface: return fw == tInterface - case *reflect.ArrayType: + case reflect.Array: if !ok || wire.ArrayT == nil { return false } array := wire.ArrayT return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem, inProgress) - case *reflect.MapType: + case reflect.Map: if !ok || wire.MapT == nil { return false } MapType := wire.MapT return dec.compatibleType(t.Key(), MapType.Key, inProgress) && dec.compatibleType(t.Elem(), MapType.Elem, inProgress) - case *reflect.SliceType: + case reflect.Slice: // Is it an array of bytes? if t.Elem().Kind() == reflect.Uint8 { return fw == tBytes @@ -1043,7 +1038,7 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[re } elem := userType(t.Elem()).base return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress) - case *reflect.StructType: + case reflect.Struct: return true } return true @@ -1093,8 +1088,9 @@ func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err // it calls out to compileSingle. func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err os.Error) { rt := ut.base - srt, ok := rt.(*reflect.StructType) - if !ok || ut.isGobDecoder { + srt := rt + if srt.Kind() != reflect.Struct || + ut.isGobDecoder { return dec.compileSingle(remoteId, ut) } var wireStruct *structType @@ -1110,7 +1106,7 @@ func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEn wireStruct = wire.StructT } if wireStruct == nil { - errorf("gob: type mismatch in decoder: want struct type %s; got non-struct", rt.String()) + errorf("type mismatch in decoder: want struct type %s; got non-struct", rt.String()) } engine = new(decEngine) engine.instr = make([]decInstr, len(wireStruct.Field)) @@ -1119,7 +1115,7 @@ func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEn for fieldnum := 0; fieldnum < len(wireStruct.Field); fieldnum++ { wireField := wireStruct.Field[fieldnum] if wireField.Name == "" { - errorf("gob: empty name for remote field of type %s", wireStruct.Name) + errorf("empty name for remote field of type %s", wireStruct.Name) } ovfl := overflow(wireField.Name) // Find the field of the local type with the same name. @@ -1131,7 +1127,7 @@ func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEn continue } if !dec.compatibleType(localField.Type, wireField.Id, make(map[reflect.Type]typeId)) { - errorf("gob: wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name) + errorf("wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name) } op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name, seen) engine.instr[fieldnum] = decInstr{*op, fieldnum, indir, uintptr(localField.Offset), ovfl} @@ -1163,7 +1159,7 @@ func (dec *Decoder) getDecEnginePtr(remoteId typeId, ut *userTypeInfo) (enginePt // emptyStruct is the type we compile into when ignoring a struct value. type emptyStruct struct{} -var emptyStructType = reflect.Typeof(emptyStruct{}) +var emptyStructType = reflect.TypeOf(emptyStruct{}) // getDecEnginePtr returns the engine for the specified type when the value is to be discarded. func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) { @@ -1189,31 +1185,27 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) { defer catchError(&dec.err) // If the value is nil, it means we should just ignore this item. - if val == nil { + if !val.IsValid() { dec.decodeIgnoredValue(wireId) return } // Dereference down to the underlying struct type. ut := userType(val.Type()) base := ut.base - indir := ut.indir - if ut.isGobDecoder { - indir = int(ut.decIndir) - } var enginePtr **decEngine enginePtr, dec.err = dec.getDecEnginePtr(wireId, ut) if dec.err != nil { return } engine := *enginePtr - if st, ok := base.(*reflect.StructType); ok && !ut.isGobDecoder { + if st := base; st.Kind() == reflect.Struct && !ut.isGobDecoder { if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].StructT.Field) > 0 { name := base.Name() - errorf("gob: type mismatch: no fields matched compiling decoder for %s", name) + errorf("type mismatch: no fields matched compiling decoder for %s", name) } - dec.decodeStruct(engine, ut, uintptr(val.UnsafeAddr()), indir) + dec.decodeStruct(engine, ut, uintptr(unsafeAddr(val)), ut.indir) } else { - dec.decodeSingle(engine, ut, uintptr(val.UnsafeAddr())) + dec.decodeSingle(engine, ut, uintptr(unsafeAddr(val))) } } @@ -1234,7 +1226,7 @@ func (dec *Decoder) decodeIgnoredValue(wireId typeId) { func init() { var iop, uop decOp - switch reflect.Typeof(int(0)).Bits() { + switch reflect.TypeOf(int(0)).Bits() { case 32: iop = decInt32 uop = decUint32 @@ -1248,7 +1240,7 @@ func init() { decOpTable[reflect.Uint] = uop // Finally uintptr - switch reflect.Typeof(uintptr(0)).Bits() { + switch reflect.TypeOf(uintptr(0)).Bits() { case 32: uop = decUint32 case 64: @@ -1258,3 +1250,26 @@ func init() { } decOpTable[reflect.Uintptr] = uop } + +// Gob assumes it can call UnsafeAddr on any Value +// in order to get a pointer it can copy data from. +// Values that have just been created and do not point +// into existing structs or slices cannot be addressed, +// so simulate it by returning a pointer to a copy. +// Each call allocates once. +func unsafeAddr(v reflect.Value) uintptr { + if v.CanAddr() { + return v.UnsafeAddr() + } + x := reflect.New(v.Type()).Elem() + x.Set(v) + return x.UnsafeAddr() +} + +// Gob depends on being able to take the address +// of zeroed Values it creates, so use this wrapper instead +// of the standard reflect.Zero. +// Each call allocates once. +func allocValue(t reflect.Type) reflect.Value { + return reflect.New(t).Elem() +} diff --git a/libgo/go/gob/decoder.go b/libgo/go/gob/decoder.go index 34364161aa3..ea2f62ec503 100644 --- a/libgo/go/gob/decoder.go +++ b/libgo/go/gob/decoder.go @@ -50,7 +50,7 @@ func (dec *Decoder) recvType(id typeId) { // Type: wire := new(wireType) - dec.decodeValue(tWireType, reflect.NewValue(wire)) + dec.decodeValue(tWireType, reflect.ValueOf(wire)) if dec.err != nil { return } @@ -159,9 +159,9 @@ func (dec *Decoder) decodeTypeSequence(isInterface bool) typeId { // data item received, and must be a pointer. func (dec *Decoder) Decode(e interface{}) os.Error { if e == nil { - return dec.DecodeValue(nil) + return dec.DecodeValue(reflect.Value{}) } - value := reflect.NewValue(e) + value := reflect.ValueOf(e) // If e represents a value as opposed to a pointer, the answer won't // get back to the caller. Make sure it's a pointer. if value.Type().Kind() != reflect.Ptr { @@ -171,12 +171,18 @@ func (dec *Decoder) Decode(e interface{}) os.Error { return dec.DecodeValue(value) } -// DecodeValue reads the next value from the connection and stores -// it in the data represented by the reflection value. -// The value must be the correct type for the next -// data item received, or it may be nil, which means the -// value will be discarded. -func (dec *Decoder) DecodeValue(value reflect.Value) os.Error { +// DecodeValue reads the next value from the connection. +// If v is the zero reflect.Value (v.Kind() == Invalid), DecodeValue discards the value. +// Otherwise, it stores the value into v. In that case, v must represent +// a non-nil pointer to data or be an assignable reflect.Value (v.CanSet()) +func (dec *Decoder) DecodeValue(v reflect.Value) os.Error { + if v.IsValid() { + if v.Kind() == reflect.Ptr && !v.IsNil() { + // That's okay, we'll store through the pointer. + } else if !v.CanSet() { + return os.ErrorString("gob: DecodeValue of unassignable value") + } + } // Make sure we're single-threaded through here. dec.mutex.Lock() defer dec.mutex.Unlock() @@ -185,7 +191,7 @@ func (dec *Decoder) DecodeValue(value reflect.Value) os.Error { dec.err = nil id := dec.decodeTypeSequence(false) if dec.err == nil { - dec.decodeValue(id, value) + dec.decodeValue(id, v) } return dec.err } diff --git a/libgo/go/gob/doc.go b/libgo/go/gob/doc.go index 613974a000f..850759bbda6 100644 --- a/libgo/go/gob/doc.go +++ b/libgo/go/gob/doc.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* -The gob package manages streams of gobs - binary values exchanged between an +Package gob manages streams of gobs - binary values exchanged between an Encoder (transmitter) and a Decoder (receiver). A typical use is transporting arguments and results of remote procedure calls (RPCs) such as those provided by package "rpc". @@ -159,7 +159,7 @@ description, constructed from these types: Elem typeId Len int } - type CommonType { + type CommonType struct { Name string // the name of the struct type Id int // the id of the type, repeated so it's inside the type } diff --git a/libgo/go/gob/encode.go b/libgo/go/gob/encode.go index 5cfdb583a18..f9e691a2fa6 100644 --- a/libgo/go/gob/encode.go +++ b/libgo/go/gob/encode.go @@ -384,7 +384,7 @@ func (enc *Encoder) encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid ui up := unsafe.Pointer(elemp) if elemIndir > 0 { if up = encIndirect(up, elemIndir); up == nil { - errorf("gob: encodeArray: nil element") + errorf("encodeArray: nil element") } elemp = uintptr(up) } @@ -396,27 +396,27 @@ func (enc *Encoder) encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid ui // encodeReflectValue is a helper for maps. It encodes the value v. func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir int) { - for i := 0; i < indir && v != nil; i++ { + for i := 0; i < indir && v.IsValid(); i++ { v = reflect.Indirect(v) } - if v == nil { - errorf("gob: encodeReflectValue: nil element") + if !v.IsValid() { + errorf("encodeReflectValue: nil element") } - op(nil, state, unsafe.Pointer(v.UnsafeAddr())) + op(nil, state, unsafe.Pointer(unsafeAddr(v))) } // encodeMap encodes a map as unsigned count followed by key:value pairs. // Because map internals are not exposed, we must use reflection rather than // addresses. -func (enc *Encoder) encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elemOp encOp, keyIndir, elemIndir int) { +func (enc *Encoder) encodeMap(b *bytes.Buffer, mv reflect.Value, keyOp, elemOp encOp, keyIndir, elemIndir int) { state := enc.newEncoderState(b) state.fieldnum = -1 state.sendZero = true - keys := mv.Keys() + keys := mv.MapKeys() state.encodeUint(uint64(len(keys))) for _, key := range keys { encodeReflectValue(state, key, keyOp, keyIndir) - encodeReflectValue(state, mv.Elem(key), elemOp, elemIndir) + encodeReflectValue(state, mv.MapIndex(key), elemOp, elemIndir) } enc.freeEncoderState(state) } @@ -426,7 +426,7 @@ func (enc *Encoder) encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elem // by the type identifier (which might require defining that type right now), followed // by the concrete value. A nil value gets sent as the empty string for the name, // followed by no value. -func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) { +func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv reflect.Value) { state := enc.newEncoderState(b) state.fieldnum = -1 state.sendZero = true @@ -438,7 +438,7 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) ut := userType(iv.Elem().Type()) name, ok := concreteTypeToName[ut.base] if !ok { - errorf("gob: type not registered for interface: %s", ut.base) + errorf("type not registered for interface: %s", ut.base) } // Send the name. state.encodeUint(uint64(len(name))) @@ -525,8 +525,8 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp if op == nil { inProgress[rt] = &op // Special cases - switch t := typ.(type) { - case *reflect.SliceType: + switch t := typ; t.Kind() { + case reflect.Slice: if t.Elem().Kind() == reflect.Uint8 { op = encUint8Array break @@ -541,29 +541,29 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp state.update(i) state.enc.encodeArray(state.b, slice.Data, *elemOp, t.Elem().Size(), indir, int(slice.Len)) } - case *reflect.ArrayType: + case reflect.Array: // True arrays have size in the type. elemOp, indir := enc.encOpFor(t.Elem(), inProgress) op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { state.update(i) state.enc.encodeArray(state.b, uintptr(p), *elemOp, t.Elem().Size(), indir, t.Len()) } - case *reflect.MapType: + case reflect.Map: keyOp, keyIndir := enc.encOpFor(t.Key(), inProgress) elemOp, elemIndir := enc.encOpFor(t.Elem(), inProgress) op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { // Maps cannot be accessed by moving addresses around the way // that slices etc. can. We must recover a full reflection value for // the iteration. - v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer(p))) - mv := reflect.Indirect(v).(*reflect.MapValue) + v := reflect.ValueOf(unsafe.Unreflect(t, unsafe.Pointer(p))) + mv := reflect.Indirect(v) if !state.sendZero && mv.Len() == 0 { return } state.update(i) state.enc.encodeMap(state.b, mv, *keyOp, *elemOp, keyIndir, elemIndir) } - case *reflect.StructType: + case reflect.Struct: // Generate a closure that calls out to the engine for the nested type. enc.getEncEngine(userType(typ)) info := mustGetTypeInfo(typ) @@ -572,13 +572,13 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp // indirect through info to delay evaluation for recursive structs state.enc.encodeStruct(state.b, info.encoder, uintptr(p)) } - case *reflect.InterfaceType: + case reflect.Interface: op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { // Interfaces transmit the name and contents of the concrete // value they contain. - v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer(p))) - iv := reflect.Indirect(v).(*reflect.InterfaceValue) - if !state.sendZero && (iv == nil || iv.IsNil()) { + v := reflect.ValueOf(unsafe.Unreflect(t, unsafe.Pointer(p))) + iv := reflect.Indirect(v) + if !state.sendZero && (!iv.IsValid() || iv.IsNil()) { return } state.update(i) @@ -587,7 +587,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp } } if op == nil { - errorf("gob enc: can't happen: encode type %s", rt.String()) + errorf("can't happen: encode type %s", rt.String()) } return &op, indir } @@ -599,7 +599,7 @@ func methodIndex(rt reflect.Type, method string) int { return i } } - errorf("gob: internal error: can't find method %s", method) + errorf("internal error: can't find method %s", method) return 0 } @@ -611,7 +611,7 @@ func (enc *Encoder) gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) { rt = reflect.PtrTo(rt) } else if ut.encIndir > 0 { for i := int8(0); i < ut.encIndir; i++ { - rt = rt.(*reflect.PtrType).Elem() + rt = rt.Elem() } } var op encOp @@ -619,9 +619,9 @@ func (enc *Encoder) gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) { var v reflect.Value if ut.encIndir == -1 { // Need to climb up one level to turn value into pointer. - v = reflect.NewValue(unsafe.Unreflect(rt, unsafe.Pointer(&p))) + v = reflect.ValueOf(unsafe.Unreflect(rt, unsafe.Pointer(&p))) } else { - v = reflect.NewValue(unsafe.Unreflect(rt, p)) + v = reflect.ValueOf(unsafe.Unreflect(rt, p)) } state.update(i) state.enc.encodeGobEncoder(state.b, v, methodIndex(rt, gobEncodeMethodName)) @@ -631,14 +631,15 @@ func (enc *Encoder) gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) { // compileEnc returns the engine to compile the type. func (enc *Encoder) compileEnc(ut *userTypeInfo) *encEngine { - srt, isStruct := ut.base.(*reflect.StructType) + srt := ut.base engine := new(encEngine) seen := make(map[reflect.Type]*encOp) rt := ut.base if ut.isGobEncoder { rt = ut.user } - if !ut.isGobEncoder && isStruct { + if !ut.isGobEncoder && + srt.Kind() == reflect.Struct { for fieldNum, wireFieldNum := 0, 0; fieldNum < srt.NumField(); fieldNum++ { f := srt.Field(fieldNum) if !isExported(f.Name) { @@ -649,7 +650,7 @@ func (enc *Encoder) compileEnc(ut *userTypeInfo) *encEngine { wireFieldNum++ } if srt.NumField() > 0 && len(engine.instr) == 0 { - errorf("gob: type %s has no exported fields", rt) + errorf("type %s has no exported fields", rt) } engine.instr = append(engine.instr, encInstr{encStructTerminator, 0, 0, 0}) } else { @@ -694,8 +695,8 @@ func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInf value = reflect.Indirect(value) } if !ut.isGobEncoder && value.Type().Kind() == reflect.Struct { - enc.encodeStruct(b, engine, value.UnsafeAddr()) + enc.encodeStruct(b, engine, unsafeAddr(value)) } else { - enc.encodeSingle(b, engine, value.UnsafeAddr()) + enc.encodeSingle(b, engine, unsafeAddr(value)) } } diff --git a/libgo/go/gob/encoder.go b/libgo/go/gob/encoder.go index e52a4de29f7..65ee5bf67c8 100644 --- a/libgo/go/gob/encoder.go +++ b/libgo/go/gob/encoder.go @@ -97,7 +97,7 @@ func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTyp // Id: state.encodeInt(-int64(info.id)) // Type: - enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo) + enc.encode(state.b, reflect.ValueOf(info.wire), wireTypeUserInfo) enc.writeMessage(w, state.b) if enc.err != nil { return @@ -109,12 +109,15 @@ func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTyp enc.sent[ut.user] = info.id } // Now send the inner types - switch st := actual.(type) { - case *reflect.StructType: + switch st := actual; st.Kind() { + case reflect.Struct: for i := 0; i < st.NumField(); i++ { enc.sendType(w, state, st.Field(i).Type) } - case reflect.ArrayOrSliceType: + case reflect.Array, reflect.Slice: + enc.sendType(w, state, st.Elem()) + case reflect.Map: + enc.sendType(w, state, st.Key()) enc.sendType(w, state, st.Elem()) } return true @@ -130,27 +133,27 @@ func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Typ } // It's a concrete value, so drill down to the base type. - switch rt := ut.base.(type) { + switch rt := ut.base; rt.Kind() { default: // Basic types and interfaces do not need to be described. return - case *reflect.SliceType: + case reflect.Slice: // If it's []uint8, don't send; it's considered basic. if rt.Elem().Kind() == reflect.Uint8 { return } // Otherwise we do send. break - case *reflect.ArrayType: + case reflect.Array: // arrays must be sent so we know their lengths and element types. break - case *reflect.MapType: + case reflect.Map: // maps must be sent so we know their lengths and key/value types. break - case *reflect.StructType: + case reflect.Struct: // structs must be sent so we know their fields. break - case *reflect.ChanType, *reflect.FuncType: + case reflect.Chan, reflect.Func: // Probably a bad field in a struct. enc.badType(rt) return @@ -162,7 +165,7 @@ func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Typ // Encode transmits the data item represented by the empty interface value, // guaranteeing that all necessary type information has been transmitted first. func (enc *Encoder) Encode(e interface{}) os.Error { - return enc.EncodeValue(reflect.NewValue(e)) + return enc.EncodeValue(reflect.ValueOf(e)) } // sendTypeDescriptor makes sure the remote side knows about this type. diff --git a/libgo/go/gob/encoder_test.go b/libgo/go/gob/encoder_test.go index a0c713b81df..792afbd7752 100644 --- a/libgo/go/gob/encoder_test.go +++ b/libgo/go/gob/encoder_test.go @@ -170,7 +170,7 @@ func TestTypeToPtrType(t *testing.T) { A int } t0 := Type0{7} - t0p := (*Type0)(nil) + t0p := new(Type0) if err := encAndDec(t0, t0p); err != nil { t.Error(err) } @@ -339,7 +339,7 @@ func TestSingletons(t *testing.T) { continue } // Get rid of the pointer in the rhs - val := reflect.NewValue(test.out).(*reflect.PtrValue).Elem().Interface() + val := reflect.ValueOf(test.out).Elem().Interface() if !reflect.DeepEqual(test.in, val) { t.Errorf("decoding singleton: expected %v got %v", test.in, val) } @@ -514,3 +514,38 @@ func TestNestedInterfaces(t *testing.T) { t.Fatalf("final value %d; expected %d", inner.A, 7) } } + +// The bugs keep coming. We forgot to send map subtypes before the map. + +type Bug1Elem struct { + Name string + Id int +} + +type Bug1StructMap map[string]Bug1Elem + +func bug1EncDec(in Bug1StructMap, out *Bug1StructMap) os.Error { + return nil +} + +func TestMapBug1(t *testing.T) { + in := make(Bug1StructMap) + in["val1"] = Bug1Elem{"elem1", 1} + in["val2"] = Bug1Elem{"elem2", 2} + + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(in) + if err != nil { + t.Fatal("encode:", err) + } + dec := NewDecoder(b) + out := make(Bug1StructMap) + err = dec.Decode(&out) + if err != nil { + t.Fatal("decode:", err) + } + if !reflect.DeepEqual(in, out) { + t.Errorf("mismatch: %v %v", in, out) + } +} diff --git a/libgo/go/gob/error.go b/libgo/go/gob/error.go index b053761fbcd..bfd38fc16d3 100644 --- a/libgo/go/gob/error.go +++ b/libgo/go/gob/error.go @@ -22,8 +22,9 @@ type gobError struct { } // errorf is like error but takes Printf-style arguments to construct an os.Error. +// It always prefixes the message with "gob: ". func errorf(format string, args ...interface{}) { - error(fmt.Errorf(format, args...)) + error(fmt.Errorf("gob: "+format, args...)) } // error wraps the argument error and uses it as the argument to panic. diff --git a/libgo/go/gob/gobencdec_test.go b/libgo/go/gob/gobencdec_test.go index 012b0995662..e94534f4c33 100644 --- a/libgo/go/gob/gobencdec_test.go +++ b/libgo/go/gob/gobencdec_test.go @@ -24,6 +24,10 @@ type StringStruct struct { s string // not an exported field } +type ArrayStruct struct { + a [8192]byte // not an exported field +} + type Gobber int type ValueGobber string // encodes with a value, decodes with a pointer. @@ -74,6 +78,18 @@ func (g *StringStruct) GobDecode(data []byte) os.Error { return nil } +func (a *ArrayStruct) GobEncode() ([]byte, os.Error) { + return a.a[:], nil +} + +func (a *ArrayStruct) GobDecode(data []byte) os.Error { + if len(data) != len(a.a) { + return os.ErrorString("wrong length in array decode") + } + copy(a.a[:], data) + return nil +} + func (g *Gobber) GobEncode() ([]byte, os.Error) { return []byte(fmt.Sprintf("VALUE=%d", *g)), nil } @@ -138,6 +154,16 @@ type GobTestIndirectEncDec struct { G ***StringStruct // indirections to the receiver. } +type GobTestArrayEncDec struct { + X int // guarantee we have something in common with GobTest* + A ArrayStruct // not a pointer. +} + +type GobTestIndirectArrayEncDec struct { + X int // guarantee we have something in common with GobTest* + A ***ArrayStruct // indirections to a large receiver. +} + func TestGobEncoderField(t *testing.T) { b := new(bytes.Buffer) // First a field that's a structure. @@ -216,6 +242,64 @@ func TestGobEncoderIndirectField(t *testing.T) { } } +// Test with a large field with methods. +func TestGobEncoderArrayField(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + var a GobTestArrayEncDec + a.X = 17 + for i := range a.A.a { + a.A.a[i] = byte(i) + } + err := enc.Encode(a) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTestArrayEncDec) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + for i, v := range x.A.a { + if v != byte(i) { + t.Errorf("expected %x got %x", byte(i), v) + break + } + } +} + +// Test an indirection to a large field with methods. +func TestGobEncoderIndirectArrayField(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + var a GobTestIndirectArrayEncDec + a.X = 17 + var array ArrayStruct + ap := &array + app := &ap + a.A = &app + for i := range array.a { + array.a[i] = byte(i) + } + err := enc.Encode(a) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTestIndirectArrayEncDec) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + for i, v := range (***x.A).a { + if v != byte(i) { + t.Errorf("expected %x got %x", byte(i), v) + break + } + } +} + // As long as the fields have the same name and implement the // interface, we can cross-connect them. Not sure it's useful // and may even be bad but it works and it's hard to prevent diff --git a/libgo/go/gob/type.go b/libgo/go/gob/type.go index 305d41980a5..c5b8fb5d9d1 100644 --- a/libgo/go/gob/type.go +++ b/libgo/go/gob/type.go @@ -60,8 +60,8 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) { // half speed. If they meet up, there's a cycle. slowpoke := ut.base // walks half as fast as ut.base for { - pt, ok := ut.base.(*reflect.PtrType) - if !ok { + pt := ut.base + if pt.Kind() != reflect.Ptr { break } ut.base = pt.Elem() @@ -70,12 +70,12 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) { return nil, os.ErrorString("can't represent recursive pointer type " + ut.base.String()) } if ut.indir%2 == 0 { - slowpoke = slowpoke.(*reflect.PtrType).Elem() + slowpoke = slowpoke.Elem() } ut.indir++ } - ut.isGobEncoder, ut.encIndir = implementsInterface(ut.user, gobEncoderCheck) - ut.isGobDecoder, ut.decIndir = implementsInterface(ut.user, gobDecoderCheck) + ut.isGobEncoder, ut.encIndir = implementsInterface(ut.user, gobEncoderInterfaceType) + ut.isGobDecoder, ut.decIndir = implementsInterface(ut.user, gobDecoderInterfaceType) userTypeCache[rt] = ut return } @@ -85,32 +85,16 @@ const ( gobDecodeMethodName = "GobDecode" ) -// implements returns whether the type implements the interface, as encoded -// in the check function. -func implements(typ reflect.Type, check func(typ reflect.Type) bool) bool { - if typ.NumMethod() == 0 { // avoid allocations etc. unless there's some chance - return false - } - return check(typ) -} - -// gobEncoderCheck makes the type assertion a boolean function. -func gobEncoderCheck(typ reflect.Type) bool { - _, ok := reflect.MakeZero(typ).Interface().(GobEncoder) - return ok -} - -// gobDecoderCheck makes the type assertion a boolean function. -func gobDecoderCheck(typ reflect.Type) bool { - _, ok := reflect.MakeZero(typ).Interface().(GobDecoder) - return ok -} +var ( + gobEncoderInterfaceType = reflect.TypeOf(new(GobEncoder)).Elem() + gobDecoderInterfaceType = reflect.TypeOf(new(GobDecoder)).Elem() +) // implementsInterface reports whether the type implements the -// interface. (The actual check is done through the provided function.) +// gobEncoder/gobDecoder interface. // It also returns the number of indirections required to get to the // implementation. -func implementsInterface(typ reflect.Type, check func(typ reflect.Type) bool) (success bool, indir int8) { +func implementsInterface(typ, gobEncDecType reflect.Type) (success bool, indir int8) { if typ == nil { return } @@ -118,10 +102,10 @@ func implementsInterface(typ reflect.Type, check func(typ reflect.Type) bool) (s // The type might be a pointer and we need to keep // dereferencing to the base type until we find an implementation. for { - if implements(rt, check) { + if rt.Implements(gobEncDecType) { return true, indir } - if p, ok := rt.(*reflect.PtrType); ok { + if p := rt; p.Kind() == reflect.Ptr { indir++ if indir > 100 { // insane number of indirections return false, 0 @@ -132,9 +116,9 @@ func implementsInterface(typ reflect.Type, check func(typ reflect.Type) bool) (s break } // No luck yet, but if this is a base type (non-pointer), the pointer might satisfy. - if _, ok := typ.(*reflect.PtrType); !ok { + if typ.Kind() != reflect.Ptr { // Not a pointer, but does the pointer work? - if implements(reflect.PtrTo(typ), check) { + if reflect.PtrTo(typ).Implements(gobEncDecType) { return true, -1 } } @@ -243,18 +227,18 @@ var ( ) // Predefined because it's needed by the Decoder -var tWireType = mustGetTypeInfo(reflect.Typeof(wireType{})).id +var tWireType = mustGetTypeInfo(reflect.TypeOf(wireType{})).id var wireTypeUserInfo *userTypeInfo // userTypeInfo of (*wireType) func init() { // Some magic numbers to make sure there are no surprises. checkId(16, tWireType) - checkId(17, mustGetTypeInfo(reflect.Typeof(arrayType{})).id) - checkId(18, mustGetTypeInfo(reflect.Typeof(CommonType{})).id) - checkId(19, mustGetTypeInfo(reflect.Typeof(sliceType{})).id) - checkId(20, mustGetTypeInfo(reflect.Typeof(structType{})).id) - checkId(21, mustGetTypeInfo(reflect.Typeof(fieldType{})).id) - checkId(23, mustGetTypeInfo(reflect.Typeof(mapType{})).id) + checkId(17, mustGetTypeInfo(reflect.TypeOf(arrayType{})).id) + checkId(18, mustGetTypeInfo(reflect.TypeOf(CommonType{})).id) + checkId(19, mustGetTypeInfo(reflect.TypeOf(sliceType{})).id) + checkId(20, mustGetTypeInfo(reflect.TypeOf(structType{})).id) + checkId(21, mustGetTypeInfo(reflect.TypeOf(fieldType{})).id) + checkId(23, mustGetTypeInfo(reflect.TypeOf(mapType{})).id) builtinIdToType = make(map[typeId]gobType) for k, v := range idToType { @@ -268,7 +252,7 @@ func init() { } nextId = firstUserId registerBasics() - wireTypeUserInfo = userType(reflect.Typeof((*wireType)(nil))) + wireTypeUserInfo = userType(reflect.TypeOf((*wireType)(nil))) } // Array type @@ -431,30 +415,30 @@ func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os. }() // Install the top-level type before the subtypes (e.g. struct before // fields) so recursive types can be constructed safely. - switch t := rt.(type) { + switch t := rt; t.Kind() { // All basic types are easy: they are predefined. - case *reflect.BoolType: + case reflect.Bool: return tBool.gobType(), nil - case *reflect.IntType: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return tInt.gobType(), nil - case *reflect.UintType: + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: return tUint.gobType(), nil - case *reflect.FloatType: + case reflect.Float32, reflect.Float64: return tFloat.gobType(), nil - case *reflect.ComplexType: + case reflect.Complex64, reflect.Complex128: return tComplex.gobType(), nil - case *reflect.StringType: + case reflect.String: return tString.gobType(), nil - case *reflect.InterfaceType: + case reflect.Interface: return tInterface.gobType(), nil - case *reflect.ArrayType: + case reflect.Array: at := newArrayType(name) types[rt] = at type0, err = getBaseType("", t.Elem()) @@ -472,7 +456,7 @@ func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os. at.init(type0, t.Len()) return at, nil - case *reflect.MapType: + case reflect.Map: mt := newMapType(name) types[rt] = mt type0, err = getBaseType("", t.Key()) @@ -486,7 +470,7 @@ func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os. mt.init(type0, type1) return mt, nil - case *reflect.SliceType: + case reflect.Slice: // []byte == []uint8 is a special case if t.Elem().Kind() == reflect.Uint8 { return tBytes.gobType(), nil @@ -500,7 +484,7 @@ func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os. st.init(type0) return st, nil - case *reflect.StructType: + case reflect.Struct: st := newStructType(name) types[rt] = st idToType[st.id()] = st @@ -569,7 +553,7 @@ func checkId(want, got typeId) { // used for building the basic types; called only from init(). the incoming // interface always refers to a pointer. func bootstrapType(name string, e interface{}, expect typeId) typeId { - rt := reflect.Typeof(e).(*reflect.PtrType).Elem() + rt := reflect.TypeOf(e).Elem() _, present := types[rt] if present { panic("bootstrap type already present: " + name + ", " + rt.String()) @@ -658,17 +642,17 @@ func getTypeInfo(ut *userTypeInfo) (*typeInfo, os.Error) { } t := info.id.gobType() - switch typ := rt.(type) { - case *reflect.ArrayType: + switch typ := rt; typ.Kind() { + case reflect.Array: info.wire = &wireType{ArrayT: t.(*arrayType)} - case *reflect.MapType: + case reflect.Map: info.wire = &wireType{MapT: t.(*mapType)} - case *reflect.SliceType: + case reflect.Slice: // []byte == []uint8 is a special case handled separately if typ.Elem().Kind() != reflect.Uint8 { info.wire = &wireType{SliceT: t.(*sliceType)} } - case *reflect.StructType: + case reflect.Struct: info.wire = &wireType{StructT: t.(*structType)} } typeInfoMap[rt] = info @@ -723,7 +707,7 @@ func RegisterName(name string, value interface{}) { // reserved for nil panic("attempt to register empty name") } - base := userType(reflect.Typeof(value)).base + base := userType(reflect.TypeOf(value)).base // Check for incompatible duplicates. if t, ok := nameToConcreteType[name]; ok && t != base { panic("gob: registering duplicate types for " + name) @@ -732,7 +716,7 @@ func RegisterName(name string, value interface{}) { panic("gob: registering duplicate names for " + base.String()) } // Store the name and type provided by the user.... - nameToConcreteType[name] = reflect.Typeof(value) + nameToConcreteType[name] = reflect.TypeOf(value) // but the flattened type in the type table, since that's what decode needs. concreteTypeToName[base] = name } @@ -745,14 +729,14 @@ func RegisterName(name string, value interface{}) { // between types and names is not a bijection. func Register(value interface{}) { // Default to printed representation for unnamed types - rt := reflect.Typeof(value) + rt := reflect.TypeOf(value) name := rt.String() // But for named types (or pointers to them), qualify with import path. // Dereference one pointer looking for a named type. star := "" if rt.Name() == "" { - if pt, ok := rt.(*reflect.PtrType); ok { + if pt := rt; pt.Kind() == reflect.Ptr { star = "*" rt = pt } diff --git a/libgo/go/gob/type_test.go b/libgo/go/gob/type_test.go index ffd1345e5c0..411ffb7971b 100644 --- a/libgo/go/gob/type_test.go +++ b/libgo/go/gob/type_test.go @@ -47,15 +47,15 @@ func TestBasic(t *testing.T) { // Reregister some basic types to check registration is idempotent. func TestReregistration(t *testing.T) { - newtyp := getTypeUnlocked("int", reflect.Typeof(int(0))) + newtyp := getTypeUnlocked("int", reflect.TypeOf(int(0))) if newtyp != tInt.gobType() { t.Errorf("reregistration of %s got new type", newtyp.string()) } - newtyp = getTypeUnlocked("uint", reflect.Typeof(uint(0))) + newtyp = getTypeUnlocked("uint", reflect.TypeOf(uint(0))) if newtyp != tUint.gobType() { t.Errorf("reregistration of %s got new type", newtyp.string()) } - newtyp = getTypeUnlocked("string", reflect.Typeof("hello")) + newtyp = getTypeUnlocked("string", reflect.TypeOf("hello")) if newtyp != tString.gobType() { t.Errorf("reregistration of %s got new type", newtyp.string()) } @@ -63,18 +63,18 @@ func TestReregistration(t *testing.T) { func TestArrayType(t *testing.T) { var a3 [3]int - a3int := getTypeUnlocked("foo", reflect.Typeof(a3)) - newa3int := getTypeUnlocked("bar", reflect.Typeof(a3)) + a3int := getTypeUnlocked("foo", reflect.TypeOf(a3)) + newa3int := getTypeUnlocked("bar", reflect.TypeOf(a3)) if a3int != newa3int { t.Errorf("second registration of [3]int creates new type") } var a4 [4]int - a4int := getTypeUnlocked("goo", reflect.Typeof(a4)) + a4int := getTypeUnlocked("goo", reflect.TypeOf(a4)) if a3int == a4int { t.Errorf("registration of [3]int creates same type as [4]int") } var b3 [3]bool - a3bool := getTypeUnlocked("", reflect.Typeof(b3)) + a3bool := getTypeUnlocked("", reflect.TypeOf(b3)) if a3int == a3bool { t.Errorf("registration of [3]bool creates same type as [3]int") } @@ -87,14 +87,14 @@ func TestArrayType(t *testing.T) { func TestSliceType(t *testing.T) { var s []int - sint := getTypeUnlocked("slice", reflect.Typeof(s)) + sint := getTypeUnlocked("slice", reflect.TypeOf(s)) var news []int - newsint := getTypeUnlocked("slice1", reflect.Typeof(news)) + newsint := getTypeUnlocked("slice1", reflect.TypeOf(news)) if sint != newsint { t.Errorf("second registration of []int creates new type") } var b []bool - sbool := getTypeUnlocked("", reflect.Typeof(b)) + sbool := getTypeUnlocked("", reflect.TypeOf(b)) if sbool == sint { t.Errorf("registration of []bool creates same type as []int") } @@ -107,14 +107,14 @@ func TestSliceType(t *testing.T) { func TestMapType(t *testing.T) { var m map[string]int - mapStringInt := getTypeUnlocked("map", reflect.Typeof(m)) + mapStringInt := getTypeUnlocked("map", reflect.TypeOf(m)) var newm map[string]int - newMapStringInt := getTypeUnlocked("map1", reflect.Typeof(newm)) + newMapStringInt := getTypeUnlocked("map1", reflect.TypeOf(newm)) if mapStringInt != newMapStringInt { t.Errorf("second registration of map[string]int creates new type") } var b map[string]bool - mapStringBool := getTypeUnlocked("", reflect.Typeof(b)) + mapStringBool := getTypeUnlocked("", reflect.TypeOf(b)) if mapStringBool == mapStringInt { t.Errorf("registration of map[string]bool creates same type as map[string]int") } @@ -143,7 +143,7 @@ type Foo struct { } func TestStructType(t *testing.T) { - sstruct := getTypeUnlocked("Foo", reflect.Typeof(Foo{})) + sstruct := getTypeUnlocked("Foo", reflect.TypeOf(Foo{})) str := sstruct.string() // If we can print it correctly, we built it correctly. expected := "Foo = struct { A int; B int; C string; D bytes; E float; F float; G Bar = struct { X string; }; H Bar; I Foo; }" diff --git a/libgo/go/hash/adler32/adler32.go b/libgo/go/hash/adler32/adler32.go index cd0c2599ac0..84943d9ae4c 100644 --- a/libgo/go/hash/adler32/adler32.go +++ b/libgo/go/hash/adler32/adler32.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the Adler-32 checksum. +// Package adler32 implements the Adler-32 checksum. // Defined in RFC 1950: // Adler-32 is composed of two sums accumulated per byte: s1 is // the sum of all bytes, s2 is the sum of all s1 values. Both sums @@ -43,8 +43,8 @@ func (d *digest) Size() int { return Size } // Add p to the running checksum a, b. func update(a, b uint32, p []byte) (aa, bb uint32) { - for i := 0; i < len(p); i++ { - a += uint32(p[i]) + for _, pi := range p { + a += uint32(pi) b += a // invariant: a <= b if b > (0xffffffff-255)/2 { diff --git a/libgo/go/hash/adler32/adler32_test.go b/libgo/go/hash/adler32/adler32_test.go index ffa5569bcdd..01f931c6859 100644 --- a/libgo/go/hash/adler32/adler32_test.go +++ b/libgo/go/hash/adler32/adler32_test.go @@ -5,6 +5,7 @@ package adler32 import ( + "bytes" "io" "testing" ) @@ -61,3 +62,16 @@ func TestGolden(t *testing.T) { } } } + +func BenchmarkGolden(b *testing.B) { + b.StopTimer() + c := New() + var buf bytes.Buffer + for _, g := range golden { + buf.Write([]byte(g.in)) + } + b.StartTimer() + for i := 0; i < b.N; i++ { + c.Write(buf.Bytes()) + } +} diff --git a/libgo/go/hash/crc32/crc32.go b/libgo/go/hash/crc32/crc32.go index 2ab0c54919d..88a44997168 100644 --- a/libgo/go/hash/crc32/crc32.go +++ b/libgo/go/hash/crc32/crc32.go @@ -2,8 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the 32-bit cyclic redundancy check, or CRC-32, checksum. -// See http://en.wikipedia.org/wiki/Cyclic_redundancy_check for information. +// Package crc32 implements the 32-bit cyclic redundancy check, or CRC-32, +// checksum. See http://en.wikipedia.org/wiki/Cyclic_redundancy_check for +// information. package crc32 import ( diff --git a/libgo/go/hash/crc64/crc64.go b/libgo/go/hash/crc64/crc64.go index 8443865645e..ae37e781cd0 100644 --- a/libgo/go/hash/crc64/crc64.go +++ b/libgo/go/hash/crc64/crc64.go @@ -2,8 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the 64-bit cyclic redundancy check, or CRC-64, checksum. -// See http://en.wikipedia.org/wiki/Cyclic_redundancy_check for information. +// Package crc64 implements the 64-bit cyclic redundancy check, or CRC-64, +// checksum. See http://en.wikipedia.org/wiki/Cyclic_redundancy_check for +// information. package crc64 import ( diff --git a/libgo/go/hash/fnv/fnv.go b/libgo/go/hash/fnv/fnv.go index 66ab5a635bf..9a1c6a0f2db 100644 --- a/libgo/go/hash/fnv/fnv.go +++ b/libgo/go/hash/fnv/fnv.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The fnv package implements FNV-1 and FNV-1a, -// non-cryptographic hash functions created by -// Glenn Fowler, Landon Curt Noll, and Phong Vo. +// Package fnv implements FNV-1 and FNV-1a, non-cryptographic hash functions +// created by Glenn Fowler, Landon Curt Noll, and Phong Vo. // See http://isthe.com/chongo/tech/comp/fnv/. package fnv diff --git a/libgo/go/hash/fnv/fnv_test.go b/libgo/go/hash/fnv/fnv_test.go index 3ea3fe6f124..429230c80b4 100644 --- a/libgo/go/hash/fnv/fnv_test.go +++ b/libgo/go/hash/fnv/fnv_test.go @@ -154,7 +154,7 @@ func benchmark(b *testing.B, h hash.Hash) { b.ResetTimer() b.SetBytes(testDataSize) data := make([]byte, testDataSize) - for i, _ := range data { + for i := range data { data[i] = byte(i + 'a') } diff --git a/libgo/go/hash/hash.go b/libgo/go/hash/hash.go index 56ac259db13..3536c0b6a64 100644 --- a/libgo/go/hash/hash.go +++ b/libgo/go/hash/hash.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// Package hash provides interfaces for hash functions. package hash import "io" diff --git a/libgo/go/html/doc.go b/libgo/go/html/doc.go index 4f5dee72da3..55135c3d05f 100644 --- a/libgo/go/html/doc.go +++ b/libgo/go/html/doc.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* -The html package implements an HTML5-compliant tokenizer and parser. +Package html implements an HTML5-compliant tokenizer and parser. Tokenization is done by creating a Tokenizer for an io.Reader r. It is the caller's responsibility to ensure that r provides UTF-8 encoded HTML. diff --git a/libgo/go/html/parse_test.go b/libgo/go/html/parse_test.go index fe955436c8a..3fa35d5dbe4 100644 --- a/libgo/go/html/parse_test.go +++ b/libgo/go/html/parse_test.go @@ -15,12 +15,6 @@ import ( "testing" ) -type devNull struct{} - -func (devNull) Write(p []byte) (int, os.Error) { - return len(p), nil -} - func pipeErr(err os.Error) io.Reader { pr, pw := io.Pipe() pw.CloseWithError(err) @@ -141,7 +135,7 @@ func TestParser(t *testing.T) { t.Fatal(err) } // Skip the #error section. - if _, err := io.Copy(devNull{}, <-rc); err != nil { + if _, err := io.Copy(ioutil.Discard, <-rc); err != nil { t.Fatal(err) } // Compare the parsed tree to the #document section. diff --git a/libgo/go/http/cgi/child.go b/libgo/go/http/cgi/child.go index c7d48b9eb3f..e1ad7ad3221 100644 --- a/libgo/go/http/cgi/child.go +++ b/libgo/go/http/cgi/child.go @@ -9,10 +9,12 @@ package cgi import ( "bufio" + "crypto/tls" "fmt" "http" "io" "io/ioutil" + "net" "os" "strconv" "strings" @@ -21,8 +23,16 @@ import ( // Request returns the HTTP request as represented in the current // environment. This assumes the current program is being run // by a web server in a CGI environment. +// The returned Request's Body is populated, if applicable. func Request() (*http.Request, os.Error) { - return requestFromEnvironment(envMap(os.Environ())) + r, err := RequestFromMap(envMap(os.Environ())) + if err != nil { + return nil, err + } + if r.ContentLength > 0 { + r.Body = ioutil.NopCloser(io.LimitReader(os.Stdin, r.ContentLength)) + } + return r, nil } func envMap(env []string) map[string]string { @@ -42,37 +52,44 @@ var skipHeader = map[string]bool{ "HTTP_USER_AGENT": true, } -func requestFromEnvironment(env map[string]string) (*http.Request, os.Error) { +// RequestFromMap creates an http.Request from CGI variables. +// The returned Request's Body field is not populated. +func RequestFromMap(params map[string]string) (*http.Request, os.Error) { r := new(http.Request) - r.Method = env["REQUEST_METHOD"] + r.Method = params["REQUEST_METHOD"] if r.Method == "" { return nil, os.NewError("cgi: no REQUEST_METHOD in environment") } + + r.Proto = params["SERVER_PROTOCOL"] + var ok bool + r.ProtoMajor, r.ProtoMinor, ok = http.ParseHTTPVersion(r.Proto) + if !ok { + return nil, os.NewError("cgi: invalid SERVER_PROTOCOL version") + } + r.Close = true r.Trailer = http.Header{} r.Header = http.Header{} - r.Host = env["HTTP_HOST"] - r.Referer = env["HTTP_REFERER"] - r.UserAgent = env["HTTP_USER_AGENT"] + r.Host = params["HTTP_HOST"] + r.Referer = params["HTTP_REFERER"] + r.UserAgent = params["HTTP_USER_AGENT"] - // CGI doesn't allow chunked requests, so these should all be accurate: - r.Proto = "HTTP/1.0" - r.ProtoMajor = 1 - r.ProtoMinor = 0 - r.TransferEncoding = nil - - if lenstr := env["CONTENT_LENGTH"]; lenstr != "" { + if lenstr := params["CONTENT_LENGTH"]; lenstr != "" { clen, err := strconv.Atoi64(lenstr) if err != nil { return nil, os.NewError("cgi: bad CONTENT_LENGTH in environment: " + lenstr) } r.ContentLength = clen - r.Body = ioutil.NopCloser(io.LimitReader(os.Stdin, clen)) + } + + if ct := params["CONTENT_TYPE"]; ct != "" { + r.Header.Set("Content-Type", ct) } // Copy "HTTP_FOO_BAR" variables to "Foo-Bar" Headers - for k, v := range env { + for k, v := range params { if !strings.HasPrefix(k, "HTTP_") || skipHeader[k] { continue } @@ -84,7 +101,7 @@ func requestFromEnvironment(env map[string]string) (*http.Request, os.Error) { if r.Host != "" { // Hostname is provided, so we can reasonably construct a URL, // even if we have to assume 'http' for the scheme. - r.RawURL = "http://" + r.Host + env["REQUEST_URI"] + r.RawURL = "http://" + r.Host + params["REQUEST_URI"] url, err := http.ParseURL(r.RawURL) if err != nil { return nil, os.NewError("cgi: failed to parse host and REQUEST_URI into a URL: " + r.RawURL) @@ -94,13 +111,25 @@ func requestFromEnvironment(env map[string]string) (*http.Request, os.Error) { // Fallback logic if we don't have a Host header or the URL // failed to parse if r.URL == nil { - r.RawURL = env["REQUEST_URI"] + r.RawURL = params["REQUEST_URI"] url, err := http.ParseURL(r.RawURL) if err != nil { return nil, os.NewError("cgi: failed to parse REQUEST_URI into a URL: " + r.RawURL) } r.URL = url } + + // There's apparently a de-facto standard for this. + // http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636 + if s := params["HTTPS"]; s == "on" || s == "ON" || s == "1" { + r.TLS = &tls.ConnectionState{HandshakeComplete: true} + } + + // Request.RemoteAddr has its port set by Go's standard http + // server, so we do here too. We don't have one, though, so we + // use a dummy one. + r.RemoteAddr = net.JoinHostPort(params["REMOTE_ADDR"], "0") + return r, nil } @@ -139,10 +168,6 @@ func (r *response) Flush() { r.bufw.Flush() } -func (r *response) RemoteAddr() string { - return os.Getenv("REMOTE_ADDR") -} - func (r *response) Header() http.Header { return r.header } @@ -168,25 +193,7 @@ func (r *response) WriteHeader(code int) { r.header.Add("Content-Type", "text/html; charset=utf-8") } - // TODO: add a method on http.Header to write itself to an io.Writer? - // This is duplicated code. - for k, vv := range r.header { - for _, v := range vv { - v = strings.Replace(v, "\n", "", -1) - v = strings.Replace(v, "\r", "", -1) - v = strings.TrimSpace(v) - fmt.Fprintf(r.bufw, "%s: %s\r\n", k, v) - } - } - r.bufw.Write([]byte("\r\n")) + r.header.Write(r.bufw) + r.bufw.WriteString("\r\n") r.bufw.Flush() } - -func (r *response) UsingTLS() bool { - // There's apparently a de-facto standard for this. - // http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636 - if s := os.Getenv("HTTPS"); s == "on" || s == "ON" || s == "1" { - return true - } - return false -} diff --git a/libgo/go/http/cgi/child_test.go b/libgo/go/http/cgi/child_test.go index db0e09cf66a..d12947814e1 100644 --- a/libgo/go/http/cgi/child_test.go +++ b/libgo/go/http/cgi/child_test.go @@ -12,6 +12,7 @@ import ( func TestRequest(t *testing.T) { env := map[string]string{ + "SERVER_PROTOCOL": "HTTP/1.1", "REQUEST_METHOD": "GET", "HTTP_HOST": "example.com", "HTTP_REFERER": "elsewhere", @@ -19,10 +20,13 @@ func TestRequest(t *testing.T) { "HTTP_FOO_BAR": "baz", "REQUEST_URI": "/path?a=b", "CONTENT_LENGTH": "123", + "CONTENT_TYPE": "text/xml", + "HTTPS": "1", + "REMOTE_ADDR": "5.6.7.8", } - req, err := requestFromEnvironment(env) + req, err := RequestFromMap(env) if err != nil { - t.Fatalf("requestFromEnvironment: %v", err) + t.Fatalf("RequestFromMap: %v", err) } if g, e := req.UserAgent, "goclient"; e != g { t.Errorf("expected UserAgent %q; got %q", e, g) @@ -34,6 +38,9 @@ func TestRequest(t *testing.T) { // Tests that we don't put recognized headers in the map t.Errorf("expected User-Agent %q; got %q", e, g) } + if g, e := req.Header.Get("Content-Type"), "text/xml"; e != g { + t.Errorf("expected Content-Type %q; got %q", e, g) + } if g, e := req.ContentLength, int64(123); e != g { t.Errorf("expected ContentLength %d; got %d", e, g) } @@ -58,18 +65,25 @@ func TestRequest(t *testing.T) { if req.Trailer == nil { t.Errorf("unexpected nil Trailer") } + if req.TLS == nil { + t.Errorf("expected non-nil TLS") + } + if e, g := "5.6.7.8:0", req.RemoteAddr; e != g { + t.Errorf("RemoteAddr: got %q; want %q", g, e) + } } func TestRequestWithoutHost(t *testing.T) { env := map[string]string{ - "HTTP_HOST": "", - "REQUEST_METHOD": "GET", - "REQUEST_URI": "/path?a=b", - "CONTENT_LENGTH": "123", + "SERVER_PROTOCOL": "HTTP/1.1", + "HTTP_HOST": "", + "REQUEST_METHOD": "GET", + "REQUEST_URI": "/path?a=b", + "CONTENT_LENGTH": "123", } - req, err := requestFromEnvironment(env) + req, err := RequestFromMap(env) if err != nil { - t.Fatalf("requestFromEnvironment: %v", err) + t.Fatalf("RequestFromMap: %v", err) } if g, e := req.RawURL, "/path?a=b"; e != g { t.Errorf("expected RawURL %q; got %q", e, g) diff --git a/libgo/go/http/cgi/host.go b/libgo/go/http/cgi/host.go index 862acb6000e..7e4ccf881d9 100644 --- a/libgo/go/http/cgi/host.go +++ b/libgo/go/http/cgi/host.go @@ -15,8 +15,8 @@ package cgi import ( + "bufio" "bytes" - "encoding/line" "exec" "fmt" "http" @@ -51,6 +51,16 @@ type Handler struct { InheritEnv []string // environment variables to inherit from host, as "key" Logger *log.Logger // optional log for errors or nil to use log.Print Args []string // optional arguments to pass to child process + + // PathLocationHandler specifies the root http Handler that + // should handle internal redirects when the CGI process + // returns a Location header value starting with a "/", as + // specified in RFC 3875 § 6.3.2. This will likely be + // http.DefaultServeMux. + // + // If nil, a CGI response with a local URI path is instead sent + // back to the client and not redirected internally. + PathLocationHandler http.Handler } func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { @@ -78,6 +88,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { env := []string{ "SERVER_SOFTWARE=go", "SERVER_NAME=" + req.Host, + "SERVER_PROTOCOL=HTTP/1.1", "HTTP_HOST=" + req.Host, "GATEWAY_INTERFACE=CGI/1.1", "REQUEST_METHOD=" + req.Method, @@ -172,14 +183,14 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { go io.Copy(cmd.Stdin, req.Body) } - linebody := line.NewReader(cmd.Stdout, 1024) - headers := rw.Header() - statusCode := http.StatusOK + linebody, _ := bufio.NewReaderSize(cmd.Stdout, 1024) + headers := make(http.Header) + statusCode := 0 for { line, isPrefix, err := linebody.ReadLine() if isPrefix { rw.WriteHeader(http.StatusInternalServerError) - h.printf("CGI: long header line from subprocess.") + h.printf("cgi: long header line from subprocess.") return } if err == os.EOF { @@ -187,7 +198,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } if err != nil { rw.WriteHeader(http.StatusInternalServerError) - h.printf("CGI: error reading headers: %v", err) + h.printf("cgi: error reading headers: %v", err) return } if len(line) == 0 { @@ -195,7 +206,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } parts := strings.Split(string(line), ":", 2) if len(parts) < 2 { - h.printf("CGI: bogus header line: %s", string(line)) + h.printf("cgi: bogus header line: %s", string(line)) continue } header, val := parts[0], parts[1] @@ -204,13 +215,13 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { switch { case header == "Status": if len(val) < 3 { - h.printf("CGI: bogus status (short): %q", val) + h.printf("cgi: bogus status (short): %q", val) return } code, err := strconv.Atoi(val[0:3]) if err != nil { - h.printf("CGI: bogus status: %q", val) - h.printf("CGI: line was %q", line) + h.printf("cgi: bogus status: %q", val) + h.printf("cgi: line was %q", line) return } statusCode = code @@ -218,11 +229,35 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { headers.Add(header, val) } } + + if loc := headers.Get("Location"); loc != "" { + if strings.HasPrefix(loc, "/") && h.PathLocationHandler != nil { + h.handleInternalRedirect(rw, req, loc) + return + } + if statusCode == 0 { + statusCode = http.StatusFound + } + } + + if statusCode == 0 { + statusCode = http.StatusOK + } + + // Copy headers to rw's headers, after we've decided not to + // go into handleInternalRedirect, which won't want its rw + // headers to have been touched. + for k, vv := range headers { + for _, v := range vv { + rw.Header().Add(k, v) + } + } + rw.WriteHeader(statusCode) _, err = io.Copy(rw, linebody) if err != nil { - h.printf("CGI: copy error: %v", err) + h.printf("cgi: copy error: %v", err) } } @@ -234,6 +269,37 @@ func (h *Handler) printf(format string, v ...interface{}) { } } +func (h *Handler) handleInternalRedirect(rw http.ResponseWriter, req *http.Request, path string) { + url, err := req.URL.ParseURL(path) + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: error resolving local URI path %q: %v", path, err) + return + } + // TODO: RFC 3875 isn't clear if only GET is supported, but it + // suggests so: "Note that any message-body attached to the + // request (such as for a POST request) may not be available + // to the resource that is the target of the redirect." We + // should do some tests against Apache to see how it handles + // POST, HEAD, etc. Does the internal redirect get the same + // method or just GET? What about incoming headers? + // (e.g. Cookies) Which headers, if any, are copied into the + // second request? + newReq := &http.Request{ + Method: "GET", + URL: url, + RawURL: path, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: url.Host, + RemoteAddr: req.RemoteAddr, + TLS: req.TLS, + } + h.PathLocationHandler.ServeHTTP(rw, newReq) +} + func upperCaseAndUnderscore(rune int) int { switch { case rune >= 'a' && rune <= 'z': diff --git a/libgo/go/http/cgi/host_test.go b/libgo/go/http/cgi/host_test.go index e8084b1134e..9ac085f2f3a 100644 --- a/libgo/go/http/cgi/host_test.go +++ b/libgo/go/http/cgi/host_test.go @@ -271,3 +271,40 @@ Transfer-Encoding: chunked expected, got) } } + +func TestRedirect(t *testing.T) { + if skipTest(t) { + return + } + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + rec := runCgiTest(t, h, "GET /test.cgi?loc=http://foo.com/ HTTP/1.0\nHost: example.com\n\n", nil) + if e, g := 302, rec.Code; e != g { + t.Errorf("expected status code %d; got %d", e, g) + } + if e, g := "http://foo.com/", rec.Header().Get("Location"); e != g { + t.Errorf("expected Location header of %q; got %q", e, g) + } +} + +func TestInternalRedirect(t *testing.T) { + if skipTest(t) { + return + } + baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path) + fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr) + }) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + PathLocationHandler: baseHandler, + } + expectedMap := map[string]string{ + "basepath": "/foo", + "remoteaddr": "1.2.3.4", + } + runCgiTest(t, h, "GET /test.cgi?loc=/foo HTTP/1.0\nHost: example.com\n\n", expectedMap) +} diff --git a/libgo/go/http/client.go b/libgo/go/http/client.go index daba3a89b0c..d73cbc8550c 100644 --- a/libgo/go/http/client.go +++ b/libgo/go/http/client.go @@ -22,6 +22,16 @@ import ( // Client is not yet very configurable. type Client struct { Transport RoundTripper // if nil, DefaultTransport is used + + // If CheckRedirect is not nil, the client calls it before + // following an HTTP redirect. The arguments req and via + // are the upcoming request and the requests made already, + // oldest first. If CheckRedirect returns an error, the client + // returns that error instead of issue the Request req. + // + // If CheckRedirect is nil, the Client uses its default policy, + // which is to stop after 10 consecutive requests. + CheckRedirect func(req *Request, via []*Request) os.Error } // DefaultClient is the default Client and is used by Get, Head, and Post. @@ -109,7 +119,7 @@ func shouldRedirect(statusCode int) bool { } // Get issues a GET to the specified URL. If the response is one of the following -// redirect codes, it follows the redirect, up to a maximum of 10 redirects: +// redirect codes, Get follows the redirect, up to a maximum of 10 redirects: // // 301 (Moved Permanently) // 302 (Found) @@ -126,35 +136,33 @@ func Get(url string) (r *Response, finalURL string, err os.Error) { return DefaultClient.Get(url) } -// Get issues a GET to the specified URL. If the response is one of the following -// redirect codes, it follows the redirect, up to a maximum of 10 redirects: +// Get issues a GET to the specified URL. If the response is one of the +// following redirect codes, Get follows the redirect after calling the +// Client's CheckRedirect function. // // 301 (Moved Permanently) // 302 (Found) // 303 (See Other) // 307 (Temporary Redirect) // -// finalURL is the URL from which the response was fetched -- identical to the -// input URL unless redirects were followed. +// finalURL is the URL from which the response was fetched -- identical +// to the input URL unless redirects were followed. // // Caller should close r.Body when done reading from it. func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { // TODO: if/when we add cookie support, the redirected request shouldn't // necessarily supply the same cookies as the original. - // TODO: set referrer header on redirects. var base *URL - // TODO: remove this hard-coded 10 and use the Client's policy - // (ClientConfig) instead. - for redirect := 0; ; redirect++ { - if redirect >= 10 { - err = os.ErrorString("stopped after 10 redirects") - break - } + redirectChecker := c.CheckRedirect + if redirectChecker == nil { + redirectChecker = defaultCheckRedirect + } + var via []*Request + for redirect := 0; ; redirect++ { var req Request req.Method = "GET" - req.ProtoMajor = 1 - req.ProtoMinor = 1 + req.Header = make(Header) if base == nil { req.URL, err = ParseURL(url) } else { @@ -163,6 +171,19 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { if err != nil { break } + if len(via) > 0 { + // Add the Referer header. + lastReq := via[len(via)-1] + if lastReq.URL.Scheme != "https" { + req.Referer = lastReq.URL.String() + } + + err = redirectChecker(&req, via) + if err != nil { + break + } + } + url = req.URL.String() if r, err = send(&req, c.Transport); err != nil { break @@ -174,6 +195,7 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { break } base = req.URL + via = append(via, &req) continue } finalURL = url @@ -184,6 +206,13 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { return } +func defaultCheckRedirect(req *Request, via []*Request) os.Error { + if len(via) >= 10 { + return os.ErrorString("stopped after 10 redirects") + } + return nil +} + // Post issues a POST to the specified URL. // // Caller should close r.Body when done reading from it. diff --git a/libgo/go/http/client_test.go b/libgo/go/http/client_test.go index 3a6f834253b..59d62c1c9d4 100644 --- a/libgo/go/http/client_test.go +++ b/libgo/go/http/client_test.go @@ -12,6 +12,7 @@ import ( "http/httptest" "io/ioutil" "os" + "strconv" "strings" "testing" ) @@ -75,3 +76,51 @@ func TestGetRequestFormat(t *testing.T) { t.Errorf("expected non-nil request Header") } } + +func TestRedirects(t *testing.T) { + var ts *httptest.Server + ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + n, _ := strconv.Atoi(r.FormValue("n")) + // Test Referer header. (7 is arbitrary position to test at) + if n == 7 { + if g, e := r.Referer, ts.URL+"/?n=6"; e != g { + t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g) + } + } + if n < 15 { + Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusFound) + return + } + fmt.Fprintf(w, "n=%d", n) + })) + defer ts.Close() + + c := &Client{} + _, _, err := c.Get(ts.URL) + if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client, expected error %q, got %q", e, g) + } + + var checkErr os.Error + var lastVia []*Request + c = &Client{CheckRedirect: func(_ *Request, via []*Request) os.Error { + lastVia = via + return checkErr + }} + _, finalUrl, err := c.Get(ts.URL) + if e, g := "", fmt.Sprintf("%v", err); e != g { + t.Errorf("with custom client, expected error %q, got %q", e, g) + } + if !strings.HasSuffix(finalUrl, "/?n=15") { + t.Errorf("expected final url to end in /?n=15; got url %q", finalUrl) + } + if e, g := 15, len(lastVia); e != g { + t.Errorf("expected lastVia to have contained %d elements; got %d", e, g) + } + + checkErr = os.NewError("no redirects allowed") + _, finalUrl, err = c.Get(ts.URL) + if e, g := "Get /?n=1: no redirects allowed", fmt.Sprintf("%v", err); e != g { + t.Errorf("with redirects forbidden, expected error %q, got %q", e, g) + } +} diff --git a/libgo/go/http/cookie.go b/libgo/go/http/cookie.go index 2bb66e58e5c..cc51316438a 100644 --- a/libgo/go/http/cookie.go +++ b/libgo/go/http/cookie.go @@ -15,9 +15,9 @@ import ( "time" ) -// This implementation is done according to IETF draft-ietf-httpstate-cookie-23, found at +// This implementation is done according to RFC 6265: // -// http://tools.ietf.org/html/draft-ietf-httpstate-cookie-23 +// http://tools.ietf.org/html/rfc6265 // A Cookie represents an HTTP cookie as sent in the Set-Cookie header of an // HTTP response or the Cookie header of an HTTP request. @@ -142,12 +142,12 @@ func writeSetCookies(w io.Writer, kk []*Cookie) os.Error { var b bytes.Buffer for _, c := range kk { b.Reset() - fmt.Fprintf(&b, "%s=%s", c.Name, c.Value) + fmt.Fprintf(&b, "%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value)) if len(c.Path) > 0 { - fmt.Fprintf(&b, "; Path=%s", URLEscape(c.Path)) + fmt.Fprintf(&b, "; Path=%s", sanitizeValue(c.Path)) } if len(c.Domain) > 0 { - fmt.Fprintf(&b, "; Domain=%s", URLEscape(c.Domain)) + fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(c.Domain)) } if len(c.Expires.Zone) > 0 { fmt.Fprintf(&b, "; Expires=%s", c.Expires.Format(time.RFC1123)) @@ -225,7 +225,7 @@ func readCookies(h Header) []*Cookie { func writeCookies(w io.Writer, kk []*Cookie) os.Error { lines := make([]string, 0, len(kk)) for _, c := range kk { - lines = append(lines, fmt.Sprintf("Cookie: %s=%s\r\n", c.Name, c.Value)) + lines = append(lines, fmt.Sprintf("Cookie: %s=%s\r\n", sanitizeName(c.Name), sanitizeValue(c.Value))) } sort.SortStrings(lines) for _, l := range lines { @@ -236,6 +236,19 @@ func writeCookies(w io.Writer, kk []*Cookie) os.Error { return nil } +func sanitizeName(n string) string { + n = strings.Replace(n, "\n", "-", -1) + n = strings.Replace(n, "\r", "-", -1) + return n +} + +func sanitizeValue(v string) string { + v = strings.Replace(v, "\n", " ", -1) + v = strings.Replace(v, "\r", " ", -1) + v = strings.Replace(v, ";", " ", -1) + return v +} + func unquoteCookieValue(v string) string { if len(v) > 1 && v[0] == '"' && v[len(v)-1] == '"' { return v[1 : len(v)-1] diff --git a/libgo/go/http/cookie_test.go b/libgo/go/http/cookie_test.go index db09970406b..a3ae85cd6c9 100644 --- a/libgo/go/http/cookie_test.go +++ b/libgo/go/http/cookie_test.go @@ -21,9 +21,13 @@ var writeSetCookiesTests = []struct { []*Cookie{ &Cookie{Name: "cookie-1", Value: "v$1"}, &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600}, + &Cookie{Name: "cookie-3", Value: "three", Domain: ".example.com"}, + &Cookie{Name: "cookie-4", Value: "four", Path: "/restricted/"}, }, "Set-Cookie: cookie-1=v$1\r\n" + - "Set-Cookie: cookie-2=two; Max-Age=3600\r\n", + "Set-Cookie: cookie-2=two; Max-Age=3600\r\n" + + "Set-Cookie: cookie-3=three; Domain=.example.com\r\n" + + "Set-Cookie: cookie-4=four; Path=/restricted/\r\n", }, } diff --git a/libgo/go/http/dump.go b/libgo/go/http/dump.go index 306c45bc2c9..358980f7cae 100644 --- a/libgo/go/http/dump.go +++ b/libgo/go/http/dump.go @@ -31,6 +31,8 @@ func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err os.Error) { // DumpRequest is semantically a no-op, but in order to // dump the body, it reads the body data into memory and // changes req.Body to refer to the in-memory copy. +// The documentation for Request.Write details which fields +// of req are used. func DumpRequest(req *Request, body bool) (dump []byte, err os.Error) { var b bytes.Buffer save := req.Body diff --git a/libgo/go/http/export_test.go b/libgo/go/http/export_test.go index a76b70760df..3fe658641f8 100644 --- a/libgo/go/http/export_test.go +++ b/libgo/go/http/export_test.go @@ -14,7 +14,7 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) { if t.idleConn == nil { return } - for key, _ := range t.idleConn { + for key := range t.idleConn { keys = append(keys, key) } return @@ -32,3 +32,10 @@ func (t *Transport) IdleConnCountForTesting(cacheKey string) int { } return len(conns) } + +func NewTestTimeoutHandler(handler Handler, ch <-chan int64) Handler { + f := func() <-chan int64 { + return ch + } + return &timeoutHandler{handler, f, ""} +} diff --git a/libgo/go/http/fcgi/child.go b/libgo/go/http/fcgi/child.go new file mode 100644 index 00000000000..19718824c96 --- /dev/null +++ b/libgo/go/http/fcgi/child.go @@ -0,0 +1,258 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fcgi + +// This file implements FastCGI from the perspective of a child process. + +import ( + "fmt" + "http" + "http/cgi" + "io" + "net" + "os" + "time" +) + +// request holds the state for an in-progress request. As soon as it's complete, +// it's converted to an http.Request. +type request struct { + pw *io.PipeWriter + reqId uint16 + params map[string]string + buf [1024]byte + rawParams []byte + keepConn bool +} + +func newRequest(reqId uint16, flags uint8) *request { + r := &request{ + reqId: reqId, + params: map[string]string{}, + keepConn: flags&flagKeepConn != 0, + } + r.rawParams = r.buf[:0] + return r +} + +// parseParams reads an encoded []byte into Params. +func (r *request) parseParams() { + text := r.rawParams + r.rawParams = nil + for len(text) > 0 { + keyLen, n := readSize(text) + if n == 0 { + return + } + text = text[n:] + valLen, n := readSize(text) + if n == 0 { + return + } + text = text[n:] + key := readString(text, keyLen) + text = text[keyLen:] + val := readString(text, valLen) + text = text[valLen:] + r.params[key] = val + } +} + +// response implements http.ResponseWriter. +type response struct { + req *request + header http.Header + w *bufWriter + wroteHeader bool +} + +func newResponse(c *child, req *request) *response { + return &response{ + req: req, + header: http.Header{}, + w: newWriter(c.conn, typeStdout, req.reqId), + } +} + +func (r *response) Header() http.Header { + return r.header +} + +func (r *response) Write(data []byte) (int, os.Error) { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + return r.w.Write(data) +} + +func (r *response) WriteHeader(code int) { + if r.wroteHeader { + return + } + r.wroteHeader = true + if code == http.StatusNotModified { + // Must not have body. + r.header.Del("Content-Type") + r.header.Del("Content-Length") + r.header.Del("Transfer-Encoding") + } else if r.header.Get("Content-Type") == "" { + r.header.Set("Content-Type", "text/html; charset=utf-8") + } + + if r.header.Get("Date") == "" { + r.header.Set("Date", time.UTC().Format(http.TimeFormat)) + } + + fmt.Fprintf(r.w, "Status: %d %s\r\n", code, http.StatusText(code)) + r.header.Write(r.w) + r.w.WriteString("\r\n") +} + +func (r *response) Flush() { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + r.w.Flush() +} + +func (r *response) Close() os.Error { + r.Flush() + return r.w.Close() +} + +type child struct { + conn *conn + handler http.Handler +} + +func newChild(rwc net.Conn, handler http.Handler) *child { + return &child{newConn(rwc), handler} +} + +func (c *child) serve() { + requests := map[uint16]*request{} + defer c.conn.Close() + var rec record + var br beginRequest + for { + if err := rec.read(c.conn.rwc); err != nil { + return + } + + req, ok := requests[rec.h.Id] + if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues { + // The spec says to ignore unknown request IDs. + continue + } + if ok && rec.h.Type == typeBeginRequest { + // The server is trying to begin a request with the same ID + // as an in-progress request. This is an error. + return + } + + switch rec.h.Type { + case typeBeginRequest: + if err := br.read(rec.content()); err != nil { + return + } + if br.role != roleResponder { + c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole) + break + } + requests[rec.h.Id] = newRequest(rec.h.Id, br.flags) + case typeParams: + // NOTE(eds): Technically a key-value pair can straddle the boundary + // between two packets. We buffer until we've received all parameters. + if len(rec.content()) > 0 { + req.rawParams = append(req.rawParams, rec.content()...) + break + } + req.parseParams() + case typeStdin: + content := rec.content() + if req.pw == nil { + var body io.ReadCloser + if len(content) > 0 { + // body could be an io.LimitReader, but it shouldn't matter + // as long as both sides are behaving. + body, req.pw = io.Pipe() + } + go c.serveRequest(req, body) + } + if len(content) > 0 { + // TODO(eds): This blocks until the handler reads from the pipe. + // If the handler takes a long time, it might be a problem. + req.pw.Write(content) + } else if req.pw != nil { + req.pw.Close() + } + case typeGetValues: + values := map[string]string{"FCGI_MPXS_CONNS": "1"} + c.conn.writePairs(0, typeGetValuesResult, values) + case typeData: + // If the filter role is implemented, read the data stream here. + case typeAbortRequest: + requests[rec.h.Id] = nil, false + c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete) + if !req.keepConn { + // connection will close upon return + return + } + default: + b := make([]byte, 8) + b[0] = rec.h.Type + c.conn.writeRecord(typeUnknownType, 0, b) + } + } +} + +func (c *child) serveRequest(req *request, body io.ReadCloser) { + r := newResponse(c, req) + httpReq, err := cgi.RequestFromMap(req.params) + if err != nil { + // there was an error reading the request + r.WriteHeader(http.StatusInternalServerError) + c.conn.writeRecord(typeStderr, req.reqId, []byte(err.String())) + } else { + httpReq.Body = body + c.handler.ServeHTTP(r, httpReq) + } + if body != nil { + body.Close() + } + r.Close() + c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete) + if !req.keepConn { + c.conn.Close() + } +} + +// Serve accepts incoming FastCGI connections on the listener l, creating a new +// service thread for each. The service threads read requests and then call handler +// to reply to them. +// If l is nil, Serve accepts connections on stdin. +// If handler is nil, http.DefaultServeMux is used. +func Serve(l net.Listener, handler http.Handler) os.Error { + if l == nil { + var err os.Error + l, err = net.FileListener(os.Stdin) + if err != nil { + return err + } + defer l.Close() + } + if handler == nil { + handler = http.DefaultServeMux + } + for { + rw, err := l.Accept() + if err != nil { + return err + } + c := newChild(rw, handler) + go c.serve() + } + panic("unreachable") +} diff --git a/libgo/go/http/fcgi/fcgi.go b/libgo/go/http/fcgi/fcgi.go new file mode 100644 index 00000000000..8e2e1cd3cb3 --- /dev/null +++ b/libgo/go/http/fcgi/fcgi.go @@ -0,0 +1,271 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package fcgi implements the FastCGI protocol. +// Currently only the responder role is supported. +// The protocol is defined at http://www.fastcgi.com/drupal/node/6?q=node/22 +package fcgi + +// This file defines the raw protocol and some utilities used by the child and +// the host. + +import ( + "bufio" + "bytes" + "encoding/binary" + "io" + "os" + "sync" +) + +const ( + // Packet Types + typeBeginRequest = iota + 1 + typeAbortRequest + typeEndRequest + typeParams + typeStdin + typeStdout + typeStderr + typeData + typeGetValues + typeGetValuesResult + typeUnknownType +) + +// keep the connection between web-server and responder open after request +const flagKeepConn = 1 + +const ( + maxWrite = 65535 // maximum record body + maxPad = 255 +) + +const ( + roleResponder = iota + 1 // only Responders are implemented. + roleAuthorizer + roleFilter +) + +const ( + statusRequestComplete = iota + statusCantMultiplex + statusOverloaded + statusUnknownRole +) + +const headerLen = 8 + +type header struct { + Version uint8 + Type uint8 + Id uint16 + ContentLength uint16 + PaddingLength uint8 + Reserved uint8 +} + +type beginRequest struct { + role uint16 + flags uint8 + reserved [5]uint8 +} + +func (br *beginRequest) read(content []byte) os.Error { + if len(content) != 8 { + return os.NewError("fcgi: invalid begin request record") + } + br.role = binary.BigEndian.Uint16(content) + br.flags = content[2] + return nil +} + +// for padding so we don't have to allocate all the time +// not synchronized because we don't care what the contents are +var pad [maxPad]byte + +func (h *header) init(recType uint8, reqId uint16, contentLength int) { + h.Version = 1 + h.Type = recType + h.Id = reqId + h.ContentLength = uint16(contentLength) + h.PaddingLength = uint8(-contentLength & 7) +} + +// conn sends records over rwc +type conn struct { + mutex sync.Mutex + rwc io.ReadWriteCloser + + // to avoid allocations + buf bytes.Buffer + h header +} + +func newConn(rwc io.ReadWriteCloser) *conn { + return &conn{rwc: rwc} +} + +func (c *conn) Close() os.Error { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.rwc.Close() +} + +type record struct { + h header + buf [maxWrite + maxPad]byte +} + +func (rec *record) read(r io.Reader) (err os.Error) { + if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil { + return err + } + if rec.h.Version != 1 { + return os.NewError("fcgi: invalid header version") + } + n := int(rec.h.ContentLength) + int(rec.h.PaddingLength) + if _, err = io.ReadFull(r, rec.buf[:n]); err != nil { + return err + } + return nil +} + +func (r *record) content() []byte { + return r.buf[:r.h.ContentLength] +} + +// writeRecord writes and sends a single record. +func (c *conn) writeRecord(recType uint8, reqId uint16, b []byte) os.Error { + c.mutex.Lock() + defer c.mutex.Unlock() + c.buf.Reset() + c.h.init(recType, reqId, len(b)) + if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil { + return err + } + if _, err := c.buf.Write(b); err != nil { + return err + } + if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil { + return err + } + _, err := c.rwc.Write(c.buf.Bytes()) + return err +} + +func (c *conn) writeBeginRequest(reqId uint16, role uint16, flags uint8) os.Error { + b := [8]byte{byte(role >> 8), byte(role), flags} + return c.writeRecord(typeBeginRequest, reqId, b[:]) +} + +func (c *conn) writeEndRequest(reqId uint16, appStatus int, protocolStatus uint8) os.Error { + b := make([]byte, 8) + binary.BigEndian.PutUint32(b, uint32(appStatus)) + b[4] = protocolStatus + return c.writeRecord(typeEndRequest, reqId, b) +} + +func (c *conn) writePairs(recType uint8, reqId uint16, pairs map[string]string) os.Error { + w := newWriter(c, recType, reqId) + b := make([]byte, 8) + for k, v := range pairs { + n := encodeSize(b, uint32(len(k))) + n += encodeSize(b[n:], uint32(len(k))) + if _, err := w.Write(b[:n]); err != nil { + return err + } + if _, err := w.WriteString(k); err != nil { + return err + } + if _, err := w.WriteString(v); err != nil { + return err + } + } + w.Close() + return nil +} + +func readSize(s []byte) (uint32, int) { + if len(s) == 0 { + return 0, 0 + } + size, n := uint32(s[0]), 1 + if size&(1<<7) != 0 { + if len(s) < 4 { + return 0, 0 + } + n = 4 + size = binary.BigEndian.Uint32(s) + size &^= 1 << 31 + } + return size, n +} + +func readString(s []byte, size uint32) string { + if size > uint32(len(s)) { + return "" + } + return string(s[:size]) +} + +func encodeSize(b []byte, size uint32) int { + if size > 127 { + size |= 1 << 31 + binary.BigEndian.PutUint32(b, size) + return 4 + } + b[0] = byte(size) + return 1 +} + +// bufWriter encapsulates bufio.Writer but also closes the underlying stream when +// Closed. +type bufWriter struct { + closer io.Closer + *bufio.Writer +} + +func (w *bufWriter) Close() os.Error { + if err := w.Writer.Flush(); err != nil { + w.closer.Close() + return err + } + return w.closer.Close() +} + +func newWriter(c *conn, recType uint8, reqId uint16) *bufWriter { + s := &streamWriter{c: c, recType: recType, reqId: reqId} + w, _ := bufio.NewWriterSize(s, maxWrite) + return &bufWriter{s, w} +} + +// streamWriter abstracts out the separation of a stream into discrete records. +// It only writes maxWrite bytes at a time. +type streamWriter struct { + c *conn + recType uint8 + reqId uint16 +} + +func (w *streamWriter) Write(p []byte) (int, os.Error) { + nn := 0 + for len(p) > 0 { + n := len(p) + if n > maxWrite { + n = maxWrite + } + if err := w.c.writeRecord(w.recType, w.reqId, p[:n]); err != nil { + return nn, err + } + nn += n + p = p[n:] + } + return nn, nil +} + +func (w *streamWriter) Close() os.Error { + // send empty record to close the stream + return w.c.writeRecord(w.recType, w.reqId, nil) +} diff --git a/libgo/go/http/fcgi/fcgi_test.go b/libgo/go/http/fcgi/fcgi_test.go new file mode 100644 index 00000000000..16a6243295e --- /dev/null +++ b/libgo/go/http/fcgi/fcgi_test.go @@ -0,0 +1,114 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fcgi + +import ( + "bytes" + "io" + "os" + "testing" +) + +var sizeTests = []struct { + size uint32 + bytes []byte +}{ + {0, []byte{0x00}}, + {127, []byte{0x7F}}, + {128, []byte{0x80, 0x00, 0x00, 0x80}}, + {1000, []byte{0x80, 0x00, 0x03, 0xE8}}, + {33554431, []byte{0x81, 0xFF, 0xFF, 0xFF}}, +} + +func TestSize(t *testing.T) { + b := make([]byte, 4) + for i, test := range sizeTests { + n := encodeSize(b, test.size) + if !bytes.Equal(b[:n], test.bytes) { + t.Errorf("%d expected %x, encoded %x", i, test.bytes, b) + } + size, n := readSize(test.bytes) + if size != test.size { + t.Errorf("%d expected %d, read %d", i, test.size, size) + } + if len(test.bytes) != n { + t.Errorf("%d did not consume all the bytes", i) + } + } +} + +var streamTests = []struct { + desc string + recType uint8 + reqId uint16 + content []byte + raw []byte +}{ + {"single record", typeStdout, 1, nil, + []byte{1, typeStdout, 0, 1, 0, 0, 0, 0}, + }, + // this data will have to be split into two records + {"two records", typeStdin, 300, make([]byte, 66000), + bytes.Join([][]byte{ + // header for the first record + []byte{1, typeStdin, 0x01, 0x2C, 0xFF, 0xFF, 1, 0}, + make([]byte, 65536), + // header for the second + []byte{1, typeStdin, 0x01, 0x2C, 0x01, 0xD1, 7, 0}, + make([]byte, 472), + // header for the empty record + []byte{1, typeStdin, 0x01, 0x2C, 0, 0, 0, 0}, + }, + nil), + }, +} + +type nilCloser struct { + io.ReadWriter +} + +func (c *nilCloser) Close() os.Error { return nil } + +func TestStreams(t *testing.T) { + var rec record +outer: + for _, test := range streamTests { + buf := bytes.NewBuffer(test.raw) + var content []byte + for buf.Len() > 0 { + if err := rec.read(buf); err != nil { + t.Errorf("%s: error reading record: %v", test.desc, err) + continue outer + } + content = append(content, rec.content()...) + } + if rec.h.Type != test.recType { + t.Errorf("%s: got type %d expected %d", test.desc, rec.h.Type, test.recType) + continue + } + if rec.h.Id != test.reqId { + t.Errorf("%s: got request ID %d expected %d", test.desc, rec.h.Id, test.reqId) + continue + } + if !bytes.Equal(content, test.content) { + t.Errorf("%s: read wrong content", test.desc) + continue + } + buf.Reset() + c := newConn(&nilCloser{buf}) + w := newWriter(c, test.recType, test.reqId) + if _, err := w.Write(test.content); err != nil { + t.Errorf("%s: error writing record: %v", test.desc, err) + continue + } + if err := w.Close(); err != nil { + t.Errorf("%s: error closing stream: %v", test.desc, err) + continue + } + if !bytes.Equal(buf.Bytes(), test.raw) { + t.Errorf("%s: wrote wrong content", test.desc) + } + } +} diff --git a/libgo/go/http/fs.go b/libgo/go/http/fs.go index c5efffca9cd..17d5297b82c 100644 --- a/libgo/go/http/fs.go +++ b/libgo/go/http/fs.go @@ -143,7 +143,7 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { n, _ := io.ReadFull(f, buf[:]) b := buf[:n] if isText(b) { - ctype = "text-plain; charset=utf-8" + ctype = "text/plain; charset=utf-8" } else { // generic binary ctype = "application/octet-stream" diff --git a/libgo/go/http/fs_test.go b/libgo/go/http/fs_test.go index 692b9863e82..09d0981f26e 100644 --- a/libgo/go/http/fs_test.go +++ b/libgo/go/http/fs_test.go @@ -104,7 +104,7 @@ func TestServeFileContentType(t *testing.T) { t.Errorf("Content-Type mismatch: got %q, want %q", h, want) } } - get("text-plain; charset=utf-8") + get("text/plain; charset=utf-8") override = true get(ctype) } diff --git a/libgo/go/http/header.go b/libgo/go/http/header.go index 95b0f3db6bb..95140b01f2a 100644 --- a/libgo/go/http/header.go +++ b/libgo/go/http/header.go @@ -4,7 +4,14 @@ package http -import "net/textproto" +import ( + "fmt" + "io" + "net/textproto" + "os" + "sort" + "strings" +) // A Header represents the key-value pairs in an HTTP header. type Header map[string][]string @@ -35,6 +42,37 @@ func (h Header) Del(key string) { textproto.MIMEHeader(h).Del(key) } +// Write writes a header in wire format. +func (h Header) Write(w io.Writer) os.Error { + return h.WriteSubset(w, nil) +} + +// WriteSubset writes a header in wire format. +// If exclude is not nil, keys where exclude[key] == true are not written. +func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) os.Error { + keys := make([]string, 0, len(h)) + for k := range h { + if exclude == nil || !exclude[k] { + keys = append(keys, k) + } + } + sort.SortStrings(keys) + for _, k := range keys { + for _, v := range h[k] { + v = strings.Replace(v, "\n", " ", -1) + v = strings.Replace(v, "\r", " ", -1) + v = strings.TrimSpace(v) + if v == "" { + continue + } + if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil { + return err + } + } + } + return nil +} + // CanonicalHeaderKey returns the canonical format of the // header key s. The canonicalization converts the first // letter and any letter following a hyphen to upper case; diff --git a/libgo/go/http/header_test.go b/libgo/go/http/header_test.go new file mode 100644 index 00000000000..7e24cb069c6 --- /dev/null +++ b/libgo/go/http/header_test.go @@ -0,0 +1,71 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "testing" +) + +var headerWriteTests = []struct { + h Header + exclude map[string]bool + expected string +}{ + {Header{}, nil, ""}, + { + Header{ + "Content-Type": {"text/html; charset=UTF-8"}, + "Content-Length": {"0"}, + }, + nil, + "Content-Length: 0\r\nContent-Type: text/html; charset=UTF-8\r\n", + }, + { + Header{ + "Content-Length": {"0", "1", "2"}, + }, + nil, + "Content-Length: 0\r\nContent-Length: 1\r\nContent-Length: 2\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true}, + "Content-Encoding: gzip\r\nExpires: -1\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0", "1", "2"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true}, + "Content-Encoding: gzip\r\nExpires: -1\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true, "Expires": true, "Content-Encoding": true}, + "", + }, +} + +func TestHeaderWrite(t *testing.T) { + var buf bytes.Buffer + for i, test := range headerWriteTests { + test.h.WriteSubset(&buf, test.exclude) + if buf.String() != test.expected { + t.Errorf("#%d:\n got: %q\nwant: %q", i, buf.String(), test.expected) + } + buf.Reset() + } +} diff --git a/libgo/go/http/httptest/recorder.go b/libgo/go/http/httptest/recorder.go index 0dd19a617cc..f2fedefcfd1 100644 --- a/libgo/go/http/httptest/recorder.go +++ b/libgo/go/http/httptest/recorder.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The httptest package provides utilities for HTTP testing. +// Package httptest provides utilities for HTTP testing. package httptest import ( diff --git a/libgo/go/http/persist.go b/libgo/go/http/persist.go index b93c5fe4855..e4eea6815d0 100644 --- a/libgo/go/http/persist.go +++ b/libgo/go/http/persist.go @@ -20,8 +20,8 @@ var ( // A ServerConn reads requests and sends responses over an underlying // connection, until the HTTP keepalive logic commands an end. ServerConn -// does not close the underlying connection. Instead, the user calls Close -// and regains control over the connection. ServerConn supports pipe-lining, +// also allows hijacking the underlying connection by calling Hijack +// to regain control over the connection. ServerConn supports pipe-lining, // i.e. requests can be read out of sync (but in the same order) while the // respective responses are sent. type ServerConn struct { @@ -45,11 +45,11 @@ func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn { return &ServerConn{c: c, r: r, pipereq: make(map[*Request]uint)} } -// Close detaches the ServerConn and returns the underlying connection as well -// as the read-side bufio which may have some left over data. Close may be +// Hijack detaches the ServerConn and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be // called before Read has signaled the end of the keep-alive logic. The user -// should not call Close while Read or Write is in progress. -func (sc *ServerConn) Close() (c net.Conn, r *bufio.Reader) { +// should not call Hijack while Read or Write is in progress. +func (sc *ServerConn) Hijack() (c net.Conn, r *bufio.Reader) { sc.lk.Lock() defer sc.lk.Unlock() c = sc.c @@ -59,6 +59,15 @@ func (sc *ServerConn) Close() (c net.Conn, r *bufio.Reader) { return } +// Close calls Hijack and then also closes the underlying connection +func (sc *ServerConn) Close() os.Error { + c, _ := sc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + // Read returns the next request on the wire. An ErrPersistEOF is returned if // it is gracefully determined that there are no more requests (e.g. after the // first request on an HTTP/1.0 connection, or after a Connection:close on a @@ -199,9 +208,9 @@ func (sc *ServerConn) Write(req *Request, resp *Response) os.Error { } // A ClientConn sends request and receives headers over an underlying -// connection, while respecting the HTTP keepalive logic. ClientConn is not -// responsible for closing the underlying connection. One must call Close to -// regain control of that connection and deal with it as desired. +// connection, while respecting the HTTP keepalive logic. ClientConn +// supports hijacking the connection calling Hijack to +// regain control of the underlying net.Conn and deal with it as desired. type ClientConn struct { lk sync.Mutex // read-write protects the following fields c net.Conn @@ -239,11 +248,11 @@ func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn { return cc } -// Close detaches the ClientConn and returns the underlying connection as well -// as the read-side bufio which may have some left over data. Close may be +// Hijack detaches the ClientConn and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be // called before the user or Read have signaled the end of the keep-alive -// logic. The user should not call Close while Read or Write is in progress. -func (cc *ClientConn) Close() (c net.Conn, r *bufio.Reader) { +// logic. The user should not call Hijack while Read or Write is in progress. +func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) { cc.lk.Lock() defer cc.lk.Unlock() c = cc.c @@ -253,6 +262,15 @@ func (cc *ClientConn) Close() (c net.Conn, r *bufio.Reader) { return } +// Close calls Hijack and then also closes the underlying connection +func (cc *ClientConn) Close() os.Error { + c, _ := cc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + // Write writes a request. An ErrPersistEOF error is returned if the connection // has been closed in an HTTP keepalive sense. If req.Close equals true, the // keepalive connection is logically closed after this request and the opposing diff --git a/libgo/go/http/pprof/pprof.go b/libgo/go/http/pprof/pprof.go index bc79e218320..917c7f877a3 100644 --- a/libgo/go/http/pprof/pprof.go +++ b/libgo/go/http/pprof/pprof.go @@ -26,6 +26,7 @@ package pprof import ( "bufio" + "bytes" "fmt" "http" "os" @@ -88,10 +89,14 @@ func Profile(w http.ResponseWriter, r *http.Request) { func Symbol(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") + // We have to read the whole POST body before + // writing any output. Buffer the output here. + var buf bytes.Buffer + // We don't know how many symbols we have, but we // do have symbol information. Pprof only cares whether // this number is 0 (no symbols available) or > 0. - fmt.Fprintf(w, "num_symbols: 1\n") + fmt.Fprintf(&buf, "num_symbols: 1\n") var b *bufio.Reader if r.Method == "POST" { @@ -109,14 +114,19 @@ func Symbol(w http.ResponseWriter, r *http.Request) { if pc != 0 { f := runtime.FuncForPC(uintptr(pc)) if f != nil { - fmt.Fprintf(w, "%#x %s\n", pc, f.Name()) + fmt.Fprintf(&buf, "%#x %s\n", pc, f.Name()) } } // Wait until here to check for err; the last // symbol will have an err because it doesn't end in +. if err != nil { + if err != os.EOF { + fmt.Fprintf(&buf, "reading request: %v\n", err) + } break } } + + w.Write(buf.Bytes()) } diff --git a/libgo/go/http/proxy_test.go b/libgo/go/http/proxy_test.go index 7050ef5ed06..308bf44b48a 100644 --- a/libgo/go/http/proxy_test.go +++ b/libgo/go/http/proxy_test.go @@ -16,9 +16,15 @@ var UseProxyTests = []struct { host string match bool }{ - {"localhost", false}, // match completely + // Never proxy localhost: + {"localhost:80", false}, + {"127.0.0.1", false}, + {"127.0.0.2", false}, + {"[::1]", false}, + {"[::2]", true}, // not a loopback address + {"barbaz.net", false}, // match as .barbaz.net - {"foobar.com:443", false}, // have a port but match + {"foobar.com", false}, // have a port but match {"foofoobar.com", true}, // not match as a part of foobar.com {"baz.com", true}, // not match as a part of barbaz.com {"localhost.net", true}, // not match as suffix of address @@ -29,19 +35,16 @@ var UseProxyTests = []struct { func TestUseProxy(t *testing.T) { oldenv := os.Getenv("NO_PROXY") - no_proxy := "foobar.com, .barbaz.net , localhost" - os.Setenv("NO_PROXY", no_proxy) defer os.Setenv("NO_PROXY", oldenv) + no_proxy := "foobar.com, .barbaz.net" + os.Setenv("NO_PROXY", no_proxy) + tr := &Transport{} for _, test := range UseProxyTests { - if tr.useProxy(test.host) != test.match { - if test.match { - t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match) - } else { - t.Errorf("not expected: '%s' shouldn't match as '%s'", test.host, no_proxy) - } + if tr.useProxy(test.host+":80") != test.match { + t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match) } } } diff --git a/libgo/go/http/request.go b/libgo/go/http/request.go index d82894fab08..8545d75660a 100644 --- a/libgo/go/http/request.go +++ b/libgo/go/http/request.go @@ -4,9 +4,8 @@ // HTTP Request reading and parsing. -// The http package implements parsing of HTTP requests, replies, -// and URLs and provides an extensible HTTP server and a basic -// HTTP client. +// Package http implements parsing of HTTP requests, replies, and URLs and +// provides an extensible HTTP server and a basic HTTP client. package http import ( @@ -25,12 +24,17 @@ import ( ) const ( - maxLineLength = 4096 // assumed <= bufio.defaultBufSize - maxValueLength = 4096 - maxHeaderLines = 1024 - chunkSize = 4 << 10 // 4 KB chunks + maxLineLength = 4096 // assumed <= bufio.defaultBufSize + maxValueLength = 4096 + maxHeaderLines = 1024 + chunkSize = 4 << 10 // 4 KB chunks + defaultMaxMemory = 32 << 20 // 32 MB ) +// ErrMissingFile is returned by FormFile when the provided file field name +// is either not present in the request or not a file field. +var ErrMissingFile = os.ErrorString("http: no such file") + // HTTP request parsing errors. type ProtocolError struct { os.ErrorString @@ -65,9 +69,12 @@ var reqExcludeHeader = map[string]bool{ // A Request represents a parsed HTTP request header. type Request struct { - Method string // GET, POST, PUT, etc. - RawURL string // The raw URL given in the request. - URL *URL // Parsed URL. + Method string // GET, POST, PUT, etc. + RawURL string // The raw URL given in the request. + URL *URL // Parsed URL. + + // The protocol version for incoming requests. + // Outgoing requests always use HTTP/1.1. Proto string // "HTTP/1.0" ProtoMajor int // 1 ProtoMinor int // 0 @@ -134,6 +141,10 @@ type Request struct { // The parsed form. Only available after ParseForm is called. Form map[string][]string + // The parsed multipart form, including file uploads. + // Only available after ParseMultipartForm is called. + MultipartForm *multipart.Form + // Trailer maps trailer keys to values. Like for Header, if the // response has multiple trailer lines with the same key, they will be // concatenated, delimited by commas. @@ -163,9 +174,30 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { r.ProtoMajor == major && r.ProtoMinor >= minor } +// multipartByReader is a sentinel value. +// Its presence in Request.MultipartForm indicates that parsing of the request +// body has been handed off to a MultipartReader instead of ParseMultipartFrom. +var multipartByReader = &multipart.Form{ + Value: make(map[string][]string), + File: make(map[string][]*multipart.FileHeader), +} + // MultipartReader returns a MIME multipart reader if this is a // multipart/form-data POST request, else returns nil and an error. +// Use this function instead of ParseMultipartForm to +// process the request body as a stream. func (r *Request) MultipartReader() (multipart.Reader, os.Error) { + if r.MultipartForm == multipartByReader { + return nil, os.NewError("http: MultipartReader called twice") + } + if r.MultipartForm != nil { + return nil, os.NewError("http: multipart handled by ParseMultipartForm") + } + r.MultipartForm = multipartByReader + return r.multipartReader() +} + +func (r *Request) multipartReader() (multipart.Reader, os.Error) { v := r.Header.Get("Content-Type") if v == "" { return nil, ErrNotMultipart @@ -199,10 +231,14 @@ const defaultUserAgent = "Go http package" // UserAgent (defaults to defaultUserAgent) // Referer // Header +// Cookie +// ContentLength +// TransferEncoding // Body // -// If Body is present, Write forces "Transfer-Encoding: chunked" as a header -// and then closes Body when finished sending it. +// If Body is present but Content-Length is <= 0, Write adds +// "Transfer-Encoding: chunked" to the header. Body is closed after +// it is sent. func (req *Request) Write(w io.Writer) os.Error { return req.write(w, false) } @@ -264,7 +300,7 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error { // from Request, and introduce Request methods along the lines of // Response.{GetHeader,AddHeader} and string constants for "Host", // "User-Agent" and "Referer". - err = writeSortedHeader(w, req.Header, reqExcludeHeader) + err = req.Header.WriteSubset(w, reqExcludeHeader) if err != nil { return err } @@ -420,6 +456,29 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err os.Error) { return n, cr.err } +// NewRequest returns a new Request given a method, URL, and optional body. +func NewRequest(method, url string, body io.Reader) (*Request, os.Error) { + u, err := ParseURL(url) + if err != nil { + return nil, err + } + rc, ok := body.(io.ReadCloser) + if !ok && body != nil { + rc = ioutil.NopCloser(body) + } + req := &Request{ + Method: method, + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(Header), + Body: rc, + Host: u.Host, + } + return req, nil +} + // ReadRequest reads and parses a request from b. func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { @@ -549,7 +608,9 @@ func parseQuery(m map[string][]string, query string) (err os.Error) { return err } -// ParseForm parses the request body as a form for POST requests, or the raw query for GET requests. +// ParseForm parses the raw query. +// For POST requests, it also parses the request body as a form. +// ParseMultipartForm calls ParseForm automatically. // It is idempotent. func (r *Request) ParseForm() (err os.Error) { if r.Form != nil { @@ -567,18 +628,23 @@ func (r *Request) ParseForm() (err os.Error) { ct := r.Header.Get("Content-Type") switch strings.Split(ct, ";", 2)[0] { case "text/plain", "application/x-www-form-urlencoded", "": - b, e := ioutil.ReadAll(r.Body) + const maxFormSize = int64(10 << 20) // 10 MB is a lot of text. + b, e := ioutil.ReadAll(io.LimitReader(r.Body, maxFormSize+1)) if e != nil { if err == nil { err = e } break } + if int64(len(b)) > maxFormSize { + return os.NewError("http: POST too large") + } e = parseQuery(r.Form, string(b)) if err == nil { err = e } - // TODO(dsymonds): Handle multipart/form-data + case "multipart/form-data": + // handled by ParseMultipartForm default: return &badStringError{"unknown Content-Type", ct} } @@ -586,11 +652,50 @@ func (r *Request) ParseForm() (err os.Error) { return err } +// ParseMultipartForm parses a request body as multipart/form-data. +// The whole request body is parsed and up to a total of maxMemory bytes of +// its file parts are stored in memory, with the remainder stored on +// disk in temporary files. +// ParseMultipartForm calls ParseForm if necessary. +// After one call to ParseMultipartForm, subsequent calls have no effect. +func (r *Request) ParseMultipartForm(maxMemory int64) os.Error { + if r.Form == nil { + err := r.ParseForm() + if err != nil { + return err + } + } + if r.MultipartForm != nil { + return nil + } + if r.MultipartForm == multipartByReader { + return os.NewError("http: multipart handled by MultipartReader") + } + + mr, err := r.multipartReader() + if err == ErrNotMultipart { + return nil + } else if err != nil { + return err + } + + f, err := mr.ReadForm(maxMemory) + if err != nil { + return err + } + for k, v := range f.Value { + r.Form[k] = append(r.Form[k], v...) + } + r.MultipartForm = f + + return nil +} + // FormValue returns the first value for the named component of the query. -// FormValue calls ParseForm if necessary. +// FormValue calls ParseMultipartForm and ParseForm if necessary. func (r *Request) FormValue(key string) string { if r.Form == nil { - r.ParseForm() + r.ParseMultipartForm(defaultMaxMemory) } if vs := r.Form[key]; len(vs) > 0 { return vs[0] @@ -598,6 +703,27 @@ func (r *Request) FormValue(key string) string { return "" } +// FormFile returns the first file for the provided form key. +// FormFile calls ParseMultipartForm and ParseForm if necessary. +func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, os.Error) { + if r.MultipartForm == multipartByReader { + return nil, nil, os.NewError("http: multipart handled by MultipartReader") + } + if r.MultipartForm == nil { + err := r.ParseMultipartForm(defaultMaxMemory) + if err != nil { + return nil, nil, err + } + } + if r.MultipartForm != nil && r.MultipartForm.File != nil { + if fhs := r.MultipartForm.File[key]; len(fhs) > 0 { + f, err := fhs[0].Open() + return f, fhs[0], err + } + } + return nil, nil, ErrMissingFile +} + func (r *Request) expectsContinue() bool { return strings.ToLower(r.Header.Get("Expect")) == "100-continue" } diff --git a/libgo/go/http/request_test.go b/libgo/go/http/request_test.go index 19083adf624..f79d3a24240 100644 --- a/libgo/go/http/request_test.go +++ b/libgo/go/http/request_test.go @@ -10,6 +10,8 @@ import ( . "http" "http/httptest" "io" + "io/ioutil" + "mime/multipart" "os" "reflect" "regexp" @@ -82,7 +84,7 @@ func TestPostQuery(t *testing.T) { req.Header = Header{ "Content-Type": {"application/x-www-form-urlencoded; boo!"}, } - req.Body = nopCloser{strings.NewReader("z=post&both=y")} + req.Body = ioutil.NopCloser(strings.NewReader("z=post&both=y")) if q := req.FormValue("q"); q != "foo" { t.Errorf(`req.FormValue("q") = %q, want "foo"`, q) } @@ -115,7 +117,7 @@ func TestPostContentTypeParsing(t *testing.T) { req := &Request{ Method: "POST", Header: Header(test.contentType), - Body: nopCloser{bytes.NewBufferString("body")}, + Body: ioutil.NopCloser(bytes.NewBufferString("body")), } err := req.ParseForm() if !test.error && err != nil { @@ -131,7 +133,7 @@ func TestMultipartReader(t *testing.T) { req := &Request{ Method: "POST", Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}}, - Body: nopCloser{new(bytes.Buffer)}, + Body: ioutil.NopCloser(new(bytes.Buffer)), } multipart, err := req.MultipartReader() if multipart == nil { @@ -170,9 +172,143 @@ func TestRedirect(t *testing.T) { } } -// TODO: stop copy/pasting this around. move to io/ioutil? -type nopCloser struct { - io.Reader +func TestMultipartRequest(t *testing.T) { + // Test that we can read the values and files of a + // multipart request with FormValue and FormFile, + // and that ParseMultipartForm can be called multiple times. + req := newTestMultipartRequest(t) + if err := req.ParseMultipartForm(25); err != nil { + t.Fatal("ParseMultipartForm first call:", err) + } + defer req.MultipartForm.RemoveAll() + validateTestMultipartContents(t, req, false) + if err := req.ParseMultipartForm(25); err != nil { + t.Fatal("ParseMultipartForm second call:", err) + } + validateTestMultipartContents(t, req, false) +} + +func TestMultipartRequestAuto(t *testing.T) { + // Test that FormValue and FormFile automatically invoke + // ParseMultipartForm and return the right values. + req := newTestMultipartRequest(t) + defer func() { + if req.MultipartForm != nil { + req.MultipartForm.RemoveAll() + } + }() + validateTestMultipartContents(t, req, true) +} + +func TestEmptyMultipartRequest(t *testing.T) { + // Test that FormValue and FormFile automatically invoke + // ParseMultipartForm and return the right values. + req, err := NewRequest("GET", "/", nil) + if err != nil { + t.Errorf("NewRequest err = %q", err) + } + testMissingFile(t, req) +} + +func testMissingFile(t *testing.T, req *Request) { + f, fh, err := req.FormFile("missing") + if f != nil { + t.Errorf("FormFile file = %q, want nil", f, nil) + } + if fh != nil { + t.Errorf("FormFile file header = %q, want nil", fh, nil) + } + if err != ErrMissingFile { + t.Errorf("FormFile err = %q, want nil", err, ErrMissingFile) + } } -func (nopCloser) Close() os.Error { return nil } +func newTestMultipartRequest(t *testing.T) *Request { + b := bytes.NewBufferString(strings.Replace(message, "\n", "\r\n", -1)) + req, err := NewRequest("POST", "/", b) + if err != nil { + t.Fatalf("NewRequest:", err) + } + ctype := fmt.Sprintf(`multipart/form-data; boundary="%s"`, boundary) + req.Header.Set("Content-type", ctype) + return req +} + +func validateTestMultipartContents(t *testing.T, req *Request, allMem bool) { + if g, e := req.FormValue("texta"), textaValue; g != e { + t.Errorf("texta value = %q, want %q", g, e) + } + if g, e := req.FormValue("texta"), textaValue; g != e { + t.Errorf("texta value = %q, want %q", g, e) + } + if g := req.FormValue("missing"); g != "" { + t.Errorf("missing value = %q, want empty string", g) + } + + assertMem := func(n string, fd multipart.File) { + if _, ok := fd.(*os.File); ok { + t.Error(n, " is *os.File, should not be") + } + } + fd := testMultipartFile(t, req, "filea", "filea.txt", fileaContents) + assertMem("filea", fd) + fd = testMultipartFile(t, req, "fileb", "fileb.txt", filebContents) + if allMem { + assertMem("fileb", fd) + } else { + if _, ok := fd.(*os.File); !ok { + t.Errorf("fileb has unexpected underlying type %T", fd) + } + } + + testMissingFile(t, req) +} + +func testMultipartFile(t *testing.T, req *Request, key, expectFilename, expectContent string) multipart.File { + f, fh, err := req.FormFile(key) + if err != nil { + t.Fatalf("FormFile(%q):", key, err) + } + if fh.Filename != expectFilename { + t.Errorf("filename = %q, want %q", fh.Filename, expectFilename) + } + var b bytes.Buffer + _, err = io.Copy(&b, f) + if err != nil { + t.Fatal("copying contents:", err) + } + if g := b.String(); g != expectContent { + t.Errorf("contents = %q, want %q", g, expectContent) + } + return f +} + +const ( + fileaContents = "This is a test file." + filebContents = "Another test file." + textaValue = "foo" + textbValue = "bar" + boundary = `MyBoundary` +) + +const message = ` +--MyBoundary +Content-Disposition: form-data; name="filea"; filename="filea.txt" +Content-Type: text/plain + +` + fileaContents + ` +--MyBoundary +Content-Disposition: form-data; name="fileb"; filename="fileb.txt" +Content-Type: text/plain + +` + filebContents + ` +--MyBoundary +Content-Disposition: form-data; name="texta" + +` + textaValue + ` +--MyBoundary +Content-Disposition: form-data; name="textb" + +` + textbValue + ` +--MyBoundary-- +` diff --git a/libgo/go/http/requestwrite_test.go b/libgo/go/http/requestwrite_test.go index 726baa26686..bb000c701ff 100644 --- a/libgo/go/http/requestwrite_test.go +++ b/libgo/go/http/requestwrite_test.go @@ -6,7 +6,10 @@ package http import ( "bytes" + "io" "io/ioutil" + "os" + "strings" "testing" ) @@ -133,6 +136,41 @@ var reqWriteTests = []reqWriteTest{ "Transfer-Encoding: chunked\r\n\r\n" + "6\r\nabcdef\r\n0\r\n\r\n", }, + + // HTTP/1.1 POST with Content-Length, no chunking + { + Request{ + Method: "POST", + URL: &URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: true, + ContentLength: 6, + }, + + []byte("abcdef"), + + "POST /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go http package\r\n" + + "Connection: close\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + + "POST http://www.google.com/search HTTP/1.1\r\n" + + "User-Agent: Go http package\r\n" + + "Connection: close\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + }, + // default to HTTP/1.1 { Request{ @@ -189,3 +227,26 @@ func TestRequestWrite(t *testing.T) { } } } + +type closeChecker struct { + io.Reader + closed bool +} + +func (rc *closeChecker) Close() os.Error { + rc.closed = true + return nil +} + +// TestRequestWriteClosesBody tests that Request.Write does close its request.Body. +// It also indirectly tests NewRequest and that it doesn't wrap an existing Closer +// inside a NopCloser. +func TestRequestWriteClosesBody(t *testing.T) { + rc := &closeChecker{Reader: strings.NewReader("my body")} + req, _ := NewRequest("GET", "http://foo.com/", rc) + buf := new(bytes.Buffer) + req.Write(buf) + if !rc.closed { + t.Error("body not closed after write") + } +} diff --git a/libgo/go/http/response.go b/libgo/go/http/response.go index 1f725ecdddd..a65c2b14df6 100644 --- a/libgo/go/http/response.go +++ b/libgo/go/http/response.go @@ -8,11 +8,9 @@ package http import ( "bufio" - "fmt" "io" "net/textproto" "os" - "sort" "strconv" "strings" ) @@ -192,7 +190,7 @@ func (resp *Response) Write(w io.Writer) os.Error { } // Rest of header - err = writeSortedHeader(w, resp.Header, respExcludeHeader) + err = resp.Header.WriteSubset(w, respExcludeHeader) if err != nil { return err } @@ -213,27 +211,3 @@ func (resp *Response) Write(w io.Writer) os.Error { // Success return nil } - -func writeSortedHeader(w io.Writer, h Header, exclude map[string]bool) os.Error { - keys := make([]string, 0, len(h)) - for k := range h { - if exclude == nil || !exclude[k] { - keys = append(keys, k) - } - } - sort.SortStrings(keys) - for _, k := range keys { - for _, v := range h[k] { - v = strings.Replace(v, "\n", " ", -1) - v = strings.Replace(v, "\r", " ", -1) - v = strings.TrimSpace(v) - if v == "" { - continue - } - if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil { - return err - } - } - } - return nil -} diff --git a/libgo/go/http/response_test.go b/libgo/go/http/response_test.go index ef67fdd2dc3..9e77c20c40b 100644 --- a/libgo/go/http/response_test.go +++ b/libgo/go/http/response_test.go @@ -7,8 +7,12 @@ package http import ( "bufio" "bytes" + "compress/gzip" + "crypto/rand" "fmt" + "os" "io" + "io/ioutil" "reflect" "testing" ) @@ -117,7 +121,9 @@ var respTests = []respTest{ "Transfer-Encoding: chunked\r\n" + "\r\n" + "0a\r\n" + - "Body here\n" + + "Body here\n\r\n" + + "09\r\n" + + "continued\r\n" + "0\r\n" + "\r\n", @@ -134,7 +140,7 @@ var respTests = []respTest{ TransferEncoding: []string{"chunked"}, }, - "Body here\n", + "Body here\ncontinued", }, // Chunked response with Content-Length. @@ -186,6 +192,29 @@ var respTests = []respTest{ "", }, + // explicit Content-Length of 0. + { + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + RequestMethod: "GET", + Header: Header{ + "Content-Length": {"0"}, + }, + Close: false, + ContentLength: 0, + }, + + "", + }, + // Status line without a Reason-Phrase, but trailing space. // (permitted by RFC 2616) { @@ -250,9 +279,107 @@ func TestReadResponse(t *testing.T) { } } +var readResponseCloseInMiddleTests = []struct { + chunked, compressed bool +}{ + {false, false}, + {true, false}, + {true, true}, +} + +// TestReadResponseCloseInMiddle tests that closing a body after +// reading only part of its contents advances the read to the end of +// the request, right up until the next request. +func TestReadResponseCloseInMiddle(t *testing.T) { + for _, test := range readResponseCloseInMiddleTests { + fatalf := func(format string, args ...interface{}) { + args = append([]interface{}{test.chunked, test.compressed}, args...) + t.Fatalf("on test chunked=%v, compressed=%v: "+format, args...) + } + checkErr := func(err os.Error, msg string) { + if err == nil { + return + } + fatalf(msg+": %v", err) + } + var buf bytes.Buffer + buf.WriteString("HTTP/1.1 200 OK\r\n") + if test.chunked { + buf.WriteString("Transfer-Encoding: chunked\r\n") + } else { + buf.WriteString("Content-Length: 1000000\r\n") + } + var wr io.Writer = &buf + if test.chunked { + wr = &chunkedWriter{wr} + } + if test.compressed { + buf.WriteString("Content-Encoding: gzip\r\n") + var err os.Error + wr, err = gzip.NewWriter(wr) + checkErr(err, "gzip.NewWriter") + } + buf.WriteString("\r\n") + + chunk := bytes.Repeat([]byte{'x'}, 1000) + for i := 0; i < 1000; i++ { + if test.compressed { + // Otherwise this compresses too well. + _, err := io.ReadFull(rand.Reader, chunk) + checkErr(err, "rand.Reader ReadFull") + } + wr.Write(chunk) + } + if test.compressed { + err := wr.(*gzip.Compressor).Close() + checkErr(err, "compressor close") + } + if test.chunked { + buf.WriteString("0\r\n\r\n") + } + buf.WriteString("Next Request Here") + + bufr := bufio.NewReader(&buf) + resp, err := ReadResponse(bufr, "GET") + checkErr(err, "ReadResponse") + expectedLength := int64(-1) + if !test.chunked { + expectedLength = 1000000 + } + if resp.ContentLength != expectedLength { + fatalf("expected response length %d, got %d", expectedLength, resp.ContentLength) + } + if resp.Body == nil { + fatalf("nil body") + } + if test.compressed { + gzReader, err := gzip.NewReader(resp.Body) + checkErr(err, "gzip.NewReader") + resp.Body = &readFirstCloseBoth{gzReader, resp.Body} + } + + rbuf := make([]byte, 2500) + n, err := io.ReadFull(resp.Body, rbuf) + checkErr(err, "2500 byte ReadFull") + if n != 2500 { + fatalf("ReadFull only read %d bytes", n) + } + if test.compressed == false && !bytes.Equal(bytes.Repeat([]byte{'x'}, 2500), rbuf) { + fatalf("ReadFull didn't read 2500 'x'; got %q", string(rbuf)) + } + resp.Body.Close() + + rest, err := ioutil.ReadAll(bufr) + checkErr(err, "ReadAll on remainder") + if e, g := "Next Request Here", string(rest); e != g { + fatalf("for chunked=%v remainder = %q, expected %q", g, e) + } + } +} + func diff(t *testing.T, prefix string, have, want interface{}) { - hv := reflect.NewValue(have).(*reflect.PtrValue).Elem().(*reflect.StructValue) - wv := reflect.NewValue(want).(*reflect.PtrValue).Elem().(*reflect.StructValue) + hv := reflect.ValueOf(have).Elem() + wv := reflect.ValueOf(want).Elem() if hv.Type() != wv.Type() { t.Errorf("%s: type mismatch %v vs %v", prefix, hv.Type(), wv.Type()) } @@ -260,7 +387,7 @@ func diff(t *testing.T, prefix string, have, want interface{}) { hf := hv.Field(i).Interface() wf := wv.Field(i).Interface() if !reflect.DeepEqual(hf, wf) { - t.Errorf("%s: %s = %v want %v", prefix, hv.Type().(*reflect.StructType).Field(i).Name, hf, wf) + t.Errorf("%s: %s = %v want %v", prefix, hv.Type().Field(i).Name, hf, wf) } } } diff --git a/libgo/go/http/reverseproxy.go b/libgo/go/http/reverseproxy.go new file mode 100644 index 00000000000..e4ce1e34c79 --- /dev/null +++ b/libgo/go/http/reverseproxy.go @@ -0,0 +1,100 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP reverse proxy handler + +package http + +import ( + "io" + "log" + "net" + "strings" +) + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the +// client. +type ReverseProxy struct { + // Director must be a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + Director func(*Request) + + // The Transport used to perform proxy requests. + // If nil, DefaultTransport is used. + Transport RoundTripper +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites +// URLs to the scheme, host, and base path provided in target. If the +// target's path is "/base" and the incoming request was for "/dir", +// the target request will be for /base/dir. +func NewSingleHostReverseProxy(target *URL) *ReverseProxy { + director := func(req *Request) { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) + if q := req.URL.RawQuery; q != "" { + req.URL.RawPath = req.URL.Path + "?" + q + } else { + req.URL.RawPath = req.URL.Path + } + req.URL.RawQuery = target.RawQuery + } + return &ReverseProxy{Director: director} +} + +func (p *ReverseProxy) ServeHTTP(rw ResponseWriter, req *Request) { + transport := p.Transport + if transport == nil { + transport = DefaultTransport + } + + outreq := new(Request) + *outreq = *req // includes shallow copies of maps, but okay + + p.Director(outreq) + outreq.Proto = "HTTP/1.1" + outreq.ProtoMajor = 1 + outreq.ProtoMinor = 1 + outreq.Close = false + + if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + outreq.Header.Set("X-Forwarded-For", clientIp) + } + + res, err := transport.RoundTrip(outreq) + if err != nil { + log.Printf("http: proxy error: %v", err) + rw.WriteHeader(StatusInternalServerError) + return + } + + hdr := rw.Header() + for k, vv := range res.Header { + for _, v := range vv { + hdr.Add(k, v) + } + } + + rw.WriteHeader(res.StatusCode) + + if res.Body != nil { + io.Copy(rw, res.Body) + } +} diff --git a/libgo/go/http/reverseproxy_test.go b/libgo/go/http/reverseproxy_test.go new file mode 100644 index 00000000000..8cf7705d745 --- /dev/null +++ b/libgo/go/http/reverseproxy_test.go @@ -0,0 +1,50 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Reverse proxy tests. + +package http_test + +import ( + . "http" + "http/httptest" + "io/ioutil" + "testing" +) + +func TestReverseProxy(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + w.Header().Set("X-Foo", "bar") + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := ParseURL(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + res, _, err := Get(frontend.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if g, e := res.Header.Get("X-Foo"), "bar"; g != e { + t.Errorf("got X-Foo %q; expected %q", g, e) + } + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} diff --git a/libgo/go/http/serve_test.go b/libgo/go/http/serve_test.go index cf889553fb7..7ff6ef04b1a 100644 --- a/libgo/go/http/serve_test.go +++ b/libgo/go/http/serve_test.go @@ -231,7 +231,7 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { func TestServerTimeouts(t *testing.T) { // TODO(bradfitz): convert this to use httptest.Server - l, err := net.ListenTCP("tcp", &net.TCPAddr{Port: 0}) + l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen error: %v", err) } @@ -247,7 +247,7 @@ func TestServerTimeouts(t *testing.T) { server := &Server{Handler: handler, ReadTimeout: 0.25 * second, WriteTimeout: 0.25 * second} go server.Serve(l) - url := fmt.Sprintf("http://localhost:%d/", addr.Port) + url := fmt.Sprintf("http://%s/", addr) // Hit the HTTP server successfully. tr := &Transport{DisableKeepAlives: true} // they interfere with this test @@ -265,7 +265,7 @@ func TestServerTimeouts(t *testing.T) { // Slow client that should timeout. t1 := time.Nanoseconds() - conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", addr.Port)) + conn, err := net.Dial("tcp", addr.String()) if err != nil { t.Fatalf("Dial: %v", err) } @@ -534,3 +534,162 @@ func TestTLSServer(t *testing.T) { t.Errorf("expected body %q; got %q", e, g) } } + +type serverExpectTest struct { + contentLength int // of request body + expectation string // e.g. "100-continue" + readBody bool // whether handler should read the body (if false, sends StatusUnauthorized) + expectedResponse string // expected substring in first line of http response +} + +var serverExpectTests = []serverExpectTest{ + // Normal 100-continues, case-insensitive. + {100, "100-continue", true, "100 Continue"}, + {100, "100-cOntInUE", true, "100 Continue"}, + + // No 100-continue. + {100, "", true, "200 OK"}, + + // 100-continue but requesting client to deny us, + // so it never eads the body. + {100, "100-continue", false, "401 Unauthorized"}, + // Likewise without 100-continue: + {100, "", false, "401 Unauthorized"}, + + // Non-standard expectations are failures + {0, "a-pony", false, "417 Expectation Failed"}, + + // Expect-100 requested but no body + {0, "100-continue", true, "400 Bad Request"}, +} + +// Tests that the server responds to the "Expect" request header +// correctly. +func TestServerExpect(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + // Note using r.FormValue("readbody") because for POST + // requests that would read from r.Body, which we only + // conditionally want to do. + if strings.Contains(r.URL.RawPath, "readbody=true") { + ioutil.ReadAll(r.Body) + w.Write([]byte("Hi")) + } else { + w.WriteHeader(StatusUnauthorized) + } + })) + defer ts.Close() + + runTest := func(test serverExpectTest) { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + sendf := func(format string, args ...interface{}) { + _, err := fmt.Fprintf(conn, format, args...) + if err != nil { + t.Fatalf("On test %#v, error writing %q: %v", test, format, err) + } + } + go func() { + sendf("POST /?readbody=%v HTTP/1.1\r\n"+ + "Connection: close\r\n"+ + "Content-Length: %d\r\n"+ + "Expect: %s\r\nHost: foo\r\n\r\n", + test.readBody, test.contentLength, test.expectation) + if test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue" { + body := strings.Repeat("A", test.contentLength) + sendf(body) + } + }() + bufr := bufio.NewReader(conn) + line, err := bufr.ReadString('\n') + if err != nil { + t.Fatalf("ReadString: %v", err) + } + if !strings.Contains(line, test.expectedResponse) { + t.Errorf("for test %#v got first line=%q", test, line) + } + } + + for _, test := range serverExpectTests { + runTest(test) + } +} + +func TestServerConsumesRequestBody(t *testing.T) { + conn := new(testConn) + body := strings.Repeat("x", 1<<20) + conn.readBuf.Write([]byte(fmt.Sprintf( + "POST / HTTP/1.1\r\n"+ + "Host: test\r\n"+ + "Content-Length: %d\r\n"+ + "\r\n",len(body)))) + conn.readBuf.Write([]byte(body)) + + done := make(chan bool) + + ls := &oneConnListener{conn} + go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { + if conn.readBuf.Len() < len(body)/2 { + t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len()) + } + rw.WriteHeader(200) + if g, e := conn.readBuf.Len(), 0; g != e { + t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e) + } + done <- true + })) + <-done +} + +func TestTimeoutHandler(t *testing.T) { + sendHi := make(chan bool, 1) + writeErrors := make(chan os.Error, 1) + sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { + <-sendHi + _, werr := w.Write([]byte("hi")) + writeErrors <- werr + }) + timeout := make(chan int64, 1) // write to this to force timeouts + ts := httptest.NewServer(NewTestTimeoutHandler(sayHi, timeout)) + defer ts.Close() + + // Succeed without timing out: + sendHi <- true + res, _, err := Get(ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusOK; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ := ioutil.ReadAll(res.Body) + if g, e := string(body), "hi"; g != e { + t.Errorf("got body %q; expected %q", g, e) + } + if g := <-writeErrors; g != nil { + t.Errorf("got unexpected Write error on first request: %v", g) + } + + // Times out: + timeout <- 1 + res, _, err = Get(ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusServiceUnavailable; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ = ioutil.ReadAll(res.Body) + if !strings.Contains(string(body), "Timeout") { + t.Errorf("expected timeout body; got %q", string(body)) + } + + // Now make the previously-timed out handler speak again, + // which verifies the panic is handled: + sendHi <- true + if g, e := <-writeErrors, ErrHandlerTimeout; g != e { + t.Errorf("expected Write error of %v; got %v", e, g) + } +} diff --git a/libgo/go/http/server.go b/libgo/go/http/server.go index 8e7039371ae..d155f06a2d2 100644 --- a/libgo/go/http/server.go +++ b/libgo/go/http/server.go @@ -22,6 +22,7 @@ import ( "path" "strconv" "strings" + "sync" "time" ) @@ -141,9 +142,13 @@ func newConn(rwc net.Conn, handler Handler) (c *conn, err os.Error) { type expectContinueReader struct { resp *response readCloser io.ReadCloser + closed bool } func (ecr *expectContinueReader) Read(p []byte) (n int, err os.Error) { + if ecr.closed { + return 0, os.NewError("http: Read after Close on request Body") + } if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked { ecr.resp.wroteContinue = true io.WriteString(ecr.resp.conn.buf, "HTTP/1.1 100 Continue\r\n\r\n") @@ -153,6 +158,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err os.Error) { } func (ecr *expectContinueReader) Close() os.Error { + ecr.closed = true return ecr.readCloser.Close() } @@ -180,12 +186,6 @@ func (c *conn) readRequest() (w *response, err os.Error) { w.req = req w.header = make(Header) w.contentLength = -1 - - // Expect 100 Continue support - if req.expectsContinue() && req.ProtoAtLeast(1, 1) { - // Wrap the Body reader with one that replies on the connection - req.Body = &expectContinueReader{readCloser: req.Body, resp: w} - } return w, nil } @@ -202,6 +202,16 @@ func (w *response) WriteHeader(code int) { log.Print("http: multiple response.WriteHeader calls") return } + + // Per RFC 2616, we should consume the request body before + // replying, if the handler hasn't already done so. + if w.req.ContentLength != 0 { + ecr, isExpecter := w.req.Body.(*expectContinueReader) + if !isExpecter || ecr.resp.wroteContinue { + w.req.Body.Close() + } + } + w.wroteHeader = true w.status = code if code == StatusNotModified { @@ -299,7 +309,7 @@ func (w *response) WriteHeader(code int) { text = "status code " + codestring } io.WriteString(w.conn.buf, proto+" "+codestring+" "+text+"\r\n") - writeSortedHeader(w.conn.buf, w.header, nil) + w.header.Write(w.conn.buf) io.WriteString(w.conn.buf, "\r\n") } @@ -413,6 +423,9 @@ func (w *response) finishRequest() { } w.conn.buf.Flush() w.req.Body.Close() + if w.req.MultipartForm != nil { + w.req.MultipartForm.RemoveAll() + } if w.contentLength != -1 && w.contentLength != w.written { // Did not write enough. Avoid getting out of sync. @@ -446,6 +459,38 @@ func (c *conn) serve() { if err != nil { break } + + // Expect 100 Continue support + req := w.req + if req.expectsContinue() { + if req.ProtoAtLeast(1, 1) { + // Wrap the Body reader with one that replies on the connection + req.Body = &expectContinueReader{readCloser: req.Body, resp: w} + } + if req.ContentLength == 0 { + w.Header().Set("Connection", "close") + w.WriteHeader(StatusBadRequest) + break + } + req.Header.Del("Expect") + } else if req.Header.Get("Expect") != "" { + // TODO(bradfitz): let ServeHTTP handlers handle + // requests with non-standard expectation[s]? Seems + // theoretical at best, and doesn't fit into the + // current ServeHTTP model anyway. We'd need to + // make the ResponseWriter an optional + // "ExpectReplier" interface or something. + // + // For now we'll just obey RFC 2616 14.20 which says + // "If a server receives a request containing an + // Expect field that includes an expectation- + // extension that it does not support, it MUST + // respond with a 417 (Expectation Failed) status." + w.Header().Set("Connection", "close") + w.WriteHeader(StatusExpectationFailed) + break + } + // HTTP cannot have multiple simultaneous active requests.[*] // Until the server replies to this request, it can't read another, // so we might as well run the handler in this goroutine. @@ -857,3 +902,89 @@ func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Han tlsListener := tls.NewListener(conn, config) return Serve(tlsListener, handler) } + +// TimeoutHandler returns a Handler that runs h with the given time limit. +// +// The new Handler calls h.ServeHTTP to handle each request, but if a +// call runs for more than ns nanoseconds, the handler responds with +// a 503 Service Unavailable error and the given message in its body. +// (If msg is empty, a suitable default message will be sent.) +// After such a timeout, writes by h to its ResponseWriter will return +// ErrHandlerTimeout. +func TimeoutHandler(h Handler, ns int64, msg string) Handler { + f := func() <-chan int64 { + return time.After(ns) + } + return &timeoutHandler{h, f, msg} +} + +// ErrHandlerTimeout is returned on ResponseWriter Write calls +// in handlers which have timed out. +var ErrHandlerTimeout = os.NewError("http: Handler timeout") + +type timeoutHandler struct { + handler Handler + timeout func() <-chan int64 // returns channel producing a timeout + body string +} + +func (h *timeoutHandler) errorBody() string { + if h.body != "" { + return h.body + } + return "Timeout

Timeout

" +} + +func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { + done := make(chan bool) + tw := &timeoutWriter{w: w} + go func() { + h.handler.ServeHTTP(tw, r) + done <- true + }() + select { + case <-done: + return + case <-h.timeout(): + tw.mu.Lock() + defer tw.mu.Unlock() + if !tw.wroteHeader { + tw.w.WriteHeader(StatusServiceUnavailable) + tw.w.Write([]byte(h.errorBody())) + } + tw.timedOut = true + } +} + +type timeoutWriter struct { + w ResponseWriter + + mu sync.Mutex + timedOut bool + wroteHeader bool +} + +func (tw *timeoutWriter) Header() Header { + return tw.w.Header() +} + +func (tw *timeoutWriter) Write(p []byte) (int, os.Error) { + tw.mu.Lock() + timedOut := tw.timedOut + tw.mu.Unlock() + if timedOut { + return 0, ErrHandlerTimeout + } + return tw.w.Write(p) +} + +func (tw *timeoutWriter) WriteHeader(code int) { + tw.mu.Lock() + if tw.timedOut || tw.wroteHeader { + tw.mu.Unlock() + return + } + tw.wroteHeader = true + tw.mu.Unlock() + tw.w.WriteHeader(code) +} diff --git a/libgo/go/http/spdy/protocol.go b/libgo/go/http/spdy/protocol.go new file mode 100644 index 00000000000..d584ea232ea --- /dev/null +++ b/libgo/go/http/spdy/protocol.go @@ -0,0 +1,367 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package spdy is an incomplete implementation of the SPDY protocol. +// +// The implementation follows draft 2 of the spec: +// https://sites.google.com/a/chromium.org/dev/spdy/spdy-protocol/spdy-protocol-draft2 +package spdy + +import ( + "bytes" + "compress/zlib" + "encoding/binary" + "http" + "io" + "os" + "strconv" + "strings" + "sync" +) + +// Version is the protocol version number that this package implements. +const Version = 2 + +// ControlFrameType stores the type field in a control frame header. +type ControlFrameType uint16 + +// Control frame type constants +const ( + TypeSynStream ControlFrameType = 0x0001 + TypeSynReply = 0x0002 + TypeRstStream = 0x0003 + TypeSettings = 0x0004 + TypeNoop = 0x0005 + TypePing = 0x0006 + TypeGoaway = 0x0007 + TypeHeaders = 0x0008 + TypeWindowUpdate = 0x0009 +) + +func (t ControlFrameType) String() string { + switch t { + case TypeSynStream: + return "SYN_STREAM" + case TypeSynReply: + return "SYN_REPLY" + case TypeRstStream: + return "RST_STREAM" + case TypeSettings: + return "SETTINGS" + case TypeNoop: + return "NOOP" + case TypePing: + return "PING" + case TypeGoaway: + return "GOAWAY" + case TypeHeaders: + return "HEADERS" + case TypeWindowUpdate: + return "WINDOW_UPDATE" + } + return "Type(" + strconv.Itoa(int(t)) + ")" +} + +type FrameFlags uint8 + +// Stream frame flags +const ( + FlagFin FrameFlags = 0x01 + FlagUnidirectional = 0x02 +) + +// SETTINGS frame flags +const ( + FlagClearPreviouslyPersistedSettings FrameFlags = 0x01 +) + +// MaxDataLength is the maximum number of bytes that can be stored in one frame. +const MaxDataLength = 1<<24 - 1 + +// A Frame is a framed message as sent between clients and servers. +// There are two types of frames: control frames and data frames. +type Frame struct { + Header [4]byte + Flags FrameFlags + Data []byte +} + +// ControlFrame creates a control frame with the given information. +func ControlFrame(t ControlFrameType, f FrameFlags, data []byte) Frame { + return Frame{ + Header: [4]byte{ + (Version&0xff00)>>8 | 0x80, + (Version & 0x00ff), + byte((t & 0xff00) >> 8), + byte((t & 0x00ff) >> 0), + }, + Flags: f, + Data: data, + } +} + +// DataFrame creates a data frame with the given information. +func DataFrame(streamId uint32, f FrameFlags, data []byte) Frame { + return Frame{ + Header: [4]byte{ + byte(streamId & 0x7f000000 >> 24), + byte(streamId & 0x00ff0000 >> 16), + byte(streamId & 0x0000ff00 >> 8), + byte(streamId & 0x000000ff >> 0), + }, + Flags: f, + Data: data, + } +} + +// ReadFrame reads an entire frame into memory. +func ReadFrame(r io.Reader) (f Frame, err os.Error) { + _, err = io.ReadFull(r, f.Header[:]) + if err != nil { + return + } + err = binary.Read(r, binary.BigEndian, &f.Flags) + if err != nil { + return + } + var lengthField [3]byte + _, err = io.ReadFull(r, lengthField[:]) + if err != nil { + if err == os.EOF { + err = io.ErrUnexpectedEOF + } + return + } + var length uint32 + length |= uint32(lengthField[0]) << 16 + length |= uint32(lengthField[1]) << 8 + length |= uint32(lengthField[2]) << 0 + if length > 0 { + f.Data = make([]byte, int(length)) + _, err = io.ReadFull(r, f.Data) + if err == os.EOF { + err = io.ErrUnexpectedEOF + } + } else { + f.Data = []byte{} + } + return +} + +// IsControl returns whether the frame holds a control frame. +func (f Frame) IsControl() bool { + return f.Header[0]&0x80 != 0 +} + +// Type obtains the type field if the frame is a control frame, otherwise it returns zero. +func (f Frame) Type() ControlFrameType { + if !f.IsControl() { + return 0 + } + return (ControlFrameType(f.Header[2])<<8 | ControlFrameType(f.Header[3])) +} + +// StreamId returns the stream ID field if the frame is a data frame, otherwise it returns zero. +func (f Frame) StreamId() (id uint32) { + if f.IsControl() { + return 0 + } + id |= uint32(f.Header[0]) << 24 + id |= uint32(f.Header[1]) << 16 + id |= uint32(f.Header[2]) << 8 + id |= uint32(f.Header[3]) << 0 + return +} + +// WriteTo writes the frame in the SPDY format. +func (f Frame) WriteTo(w io.Writer) (n int64, err os.Error) { + var nn int + // Header + nn, err = w.Write(f.Header[:]) + n += int64(nn) + if err != nil { + return + } + // Flags + nn, err = w.Write([]byte{byte(f.Flags)}) + n += int64(nn) + if err != nil { + return + } + // Length + nn, err = w.Write([]byte{ + byte(len(f.Data) & 0x00ff0000 >> 16), + byte(len(f.Data) & 0x0000ff00 >> 8), + byte(len(f.Data) & 0x000000ff), + }) + n += int64(nn) + if err != nil { + return + } + // Data + if len(f.Data) > 0 { + nn, err = w.Write(f.Data) + n += int64(nn) + } + return +} + +// headerDictionary is the dictionary sent to the zlib compressor/decompressor. +// Even though the specification states there is no null byte at the end, Chrome sends it. +const headerDictionary = "optionsgetheadpostputdeletetrace" + + "acceptaccept-charsetaccept-encodingaccept-languageauthorizationexpectfromhost" + + "if-modified-sinceif-matchif-none-matchif-rangeif-unmodifiedsince" + + "max-forwardsproxy-authorizationrangerefererteuser-agent" + + "100101200201202203204205206300301302303304305306307400401402403404405406407408409410411412413414415416417500501502503504505" + + "accept-rangesageetaglocationproxy-authenticatepublicretry-after" + + "servervarywarningwww-authenticateallowcontent-basecontent-encodingcache-control" + + "connectiondatetrailertransfer-encodingupgradeviawarning" + + "content-languagecontent-lengthcontent-locationcontent-md5content-rangecontent-typeetagexpireslast-modifiedset-cookie" + + "MondayTuesdayWednesdayThursdayFridaySaturdaySunday" + + "JanFebMarAprMayJunJulAugSepOctNovDec" + + "chunkedtext/htmlimage/pngimage/jpgimage/gifapplication/xmlapplication/xhtmltext/plainpublicmax-age" + + "charset=iso-8859-1utf-8gzipdeflateHTTP/1.1statusversionurl\x00" + +// hrSource is a reader that passes through reads from another reader. +// When the underlying reader reaches EOF, Read will block until another reader is added via change. +type hrSource struct { + r io.Reader + m sync.RWMutex + c *sync.Cond +} + +func (src *hrSource) Read(p []byte) (n int, err os.Error) { + src.m.RLock() + for src.r == nil { + src.c.Wait() + } + n, err = src.r.Read(p) + src.m.RUnlock() + if err == os.EOF { + src.change(nil) + err = nil + } + return +} + +func (src *hrSource) change(r io.Reader) { + src.m.Lock() + defer src.m.Unlock() + src.r = r + src.c.Broadcast() +} + +// A HeaderReader reads zlib-compressed headers. +type HeaderReader struct { + source hrSource + decompressor io.ReadCloser +} + +// NewHeaderReader creates a HeaderReader with the initial dictionary. +func NewHeaderReader() (hr *HeaderReader) { + hr = new(HeaderReader) + hr.source.c = sync.NewCond(hr.source.m.RLocker()) + return +} + +// ReadHeader reads a set of headers from a reader. +func (hr *HeaderReader) ReadHeader(r io.Reader) (h http.Header, err os.Error) { + hr.source.change(r) + h, err = hr.read() + return +} + +// Decode reads a set of headers from a block of bytes. +func (hr *HeaderReader) Decode(data []byte) (h http.Header, err os.Error) { + hr.source.change(bytes.NewBuffer(data)) + h, err = hr.read() + return +} + +func (hr *HeaderReader) read() (h http.Header, err os.Error) { + var count uint16 + if hr.decompressor == nil { + hr.decompressor, err = zlib.NewReaderDict(&hr.source, []byte(headerDictionary)) + if err != nil { + return + } + } + err = binary.Read(hr.decompressor, binary.BigEndian, &count) + if err != nil { + return + } + h = make(http.Header, int(count)) + for i := 0; i < int(count); i++ { + var name, value string + name, err = readHeaderString(hr.decompressor) + if err != nil { + return + } + value, err = readHeaderString(hr.decompressor) + if err != nil { + return + } + valueList := strings.Split(string(value), "\x00", -1) + for _, v := range valueList { + h.Add(name, v) + } + } + return +} + +func readHeaderString(r io.Reader) (s string, err os.Error) { + var length uint16 + err = binary.Read(r, binary.BigEndian, &length) + if err != nil { + return + } + data := make([]byte, int(length)) + _, err = io.ReadFull(r, data) + if err != nil { + return + } + return string(data), nil +} + +// HeaderWriter will write zlib-compressed headers on different streams. +type HeaderWriter struct { + compressor *zlib.Writer + buffer *bytes.Buffer +} + +// NewHeaderWriter creates a HeaderWriter ready to compress headers. +func NewHeaderWriter(level int) (hw *HeaderWriter) { + hw = &HeaderWriter{buffer: new(bytes.Buffer)} + hw.compressor, _ = zlib.NewWriterDict(hw.buffer, level, []byte(headerDictionary)) + return +} + +// WriteHeader writes a header block directly to an output. +func (hw *HeaderWriter) WriteHeader(w io.Writer, h http.Header) (err os.Error) { + hw.write(h) + _, err = io.Copy(w, hw.buffer) + hw.buffer.Reset() + return +} + +// Encode returns a compressed header block. +func (hw *HeaderWriter) Encode(h http.Header) (data []byte) { + hw.write(h) + data = make([]byte, hw.buffer.Len()) + hw.buffer.Read(data) + return +} + +func (hw *HeaderWriter) write(h http.Header) { + binary.Write(hw.compressor, binary.BigEndian, uint16(len(h))) + for k, vals := range h { + k = strings.ToLower(k) + binary.Write(hw.compressor, binary.BigEndian, uint16(len(k))) + binary.Write(hw.compressor, binary.BigEndian, []byte(k)) + v := strings.Join(vals, "\x00") + binary.Write(hw.compressor, binary.BigEndian, uint16(len(v))) + binary.Write(hw.compressor, binary.BigEndian, []byte(v)) + } + hw.compressor.Flush() +} diff --git a/libgo/go/http/spdy/protocol_test.go b/libgo/go/http/spdy/protocol_test.go new file mode 100644 index 00000000000..998ff998bc7 --- /dev/null +++ b/libgo/go/http/spdy/protocol_test.go @@ -0,0 +1,259 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package spdy + +import ( + "bytes" + "compress/zlib" + "http" + "os" + "testing" +) + +type frameIoTest struct { + desc string + data []byte + frame Frame + readError os.Error + readOnly bool +} + +var frameIoTests = []frameIoTest{ + { + "noop frame", + []byte{ + 0x80, 0x02, 0x00, 0x05, + 0x00, 0x00, 0x00, 0x00, + }, + ControlFrame( + TypeNoop, + 0x00, + []byte{}, + ), + nil, + false, + }, + { + "ping frame", + []byte{ + 0x80, 0x02, 0x00, 0x06, + 0x00, 0x00, 0x00, 0x04, + 0x00, 0x00, 0x00, 0x01, + }, + ControlFrame( + TypePing, + 0x00, + []byte{0x00, 0x00, 0x00, 0x01}, + ), + nil, + false, + }, + { + "syn_stream frame", + []byte{ + 0x80, 0x02, 0x00, 0x01, + 0x01, 0x00, 0x00, 0x53, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x78, 0xbb, + 0xdf, 0xa2, 0x51, 0xb2, + 0x62, 0x60, 0x66, 0x60, + 0xcb, 0x4d, 0x2d, 0xc9, + 0xc8, 0x4f, 0x61, 0x60, + 0x4e, 0x4f, 0x2d, 0x61, + 0x60, 0x2e, 0x2d, 0xca, + 0x61, 0x10, 0xcb, 0x28, + 0x29, 0x29, 0xb0, 0xd2, + 0xd7, 0x2f, 0x2f, 0x2f, + 0xd7, 0x4b, 0xcf, 0xcf, + 0x4f, 0xcf, 0x49, 0xd5, + 0x4b, 0xce, 0xcf, 0xd5, + 0x67, 0x60, 0x2f, 0x4b, + 0x2d, 0x2a, 0xce, 0xcc, + 0xcf, 0x63, 0xe0, 0x00, + 0x29, 0xd0, 0x37, 0xd4, + 0x33, 0x04, 0x00, 0x00, + 0x00, 0xff, 0xff, + }, + ControlFrame( + TypeSynStream, + 0x01, + []byte{ + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x78, 0xbb, + 0xdf, 0xa2, 0x51, 0xb2, + 0x62, 0x60, 0x66, 0x60, + 0xcb, 0x4d, 0x2d, 0xc9, + 0xc8, 0x4f, 0x61, 0x60, + 0x4e, 0x4f, 0x2d, 0x61, + 0x60, 0x2e, 0x2d, 0xca, + 0x61, 0x10, 0xcb, 0x28, + 0x29, 0x29, 0xb0, 0xd2, + 0xd7, 0x2f, 0x2f, 0x2f, + 0xd7, 0x4b, 0xcf, 0xcf, + 0x4f, 0xcf, 0x49, 0xd5, + 0x4b, 0xce, 0xcf, 0xd5, + 0x67, 0x60, 0x2f, 0x4b, + 0x2d, 0x2a, 0xce, 0xcc, + 0xcf, 0x63, 0xe0, 0x00, + 0x29, 0xd0, 0x37, 0xd4, + 0x33, 0x04, 0x00, 0x00, + 0x00, 0xff, 0xff, + }, + ), + nil, + false, + }, + { + "data frame", + []byte{ + 0x00, 0x00, 0x00, 0x05, + 0x01, 0x00, 0x00, 0x04, + 0x01, 0x02, 0x03, 0x04, + }, + DataFrame( + 5, + 0x01, + []byte{0x01, 0x02, 0x03, 0x04}, + ), + nil, + false, + }, + { + "too much data", + []byte{ + 0x00, 0x00, 0x00, 0x05, + 0x01, 0x00, 0x00, 0x04, + 0x01, 0x02, 0x03, 0x04, + 0x05, 0x06, 0x07, 0x08, + }, + DataFrame( + 5, + 0x01, + []byte{0x01, 0x02, 0x03, 0x04}, + ), + nil, + true, + }, + { + "not enough data", + []byte{ + 0x00, 0x00, 0x00, 0x05, + }, + Frame{}, + os.EOF, + true, + }, +} + +func TestReadFrame(t *testing.T) { + for _, tt := range frameIoTests { + f, err := ReadFrame(bytes.NewBuffer(tt.data)) + if err != tt.readError { + t.Errorf("%s: ReadFrame: %s", tt.desc, err) + continue + } + if err == nil { + if !bytes.Equal(f.Header[:], tt.frame.Header[:]) { + t.Errorf("%s: header %q != %q", tt.desc, string(f.Header[:]), string(tt.frame.Header[:])) + } + if f.Flags != tt.frame.Flags { + t.Errorf("%s: flags %#02x != %#02x", tt.desc, f.Flags, tt.frame.Flags) + } + if !bytes.Equal(f.Data, tt.frame.Data) { + t.Errorf("%s: data %q != %q", tt.desc, string(f.Data), string(tt.frame.Data)) + } + } + } +} + +func TestWriteTo(t *testing.T) { + for _, tt := range frameIoTests { + if tt.readOnly { + continue + } + b := new(bytes.Buffer) + _, err := tt.frame.WriteTo(b) + if err != nil { + t.Errorf("%s: WriteTo: %s", tt.desc, err) + } + if !bytes.Equal(b.Bytes(), tt.data) { + t.Errorf("%s: data %q != %q", tt.desc, string(b.Bytes()), string(tt.data)) + } + } +} + +var headerDataTest = []byte{ + 0x78, 0xbb, 0xdf, 0xa2, + 0x51, 0xb2, 0x62, 0x60, + 0x66, 0x60, 0xcb, 0x4d, + 0x2d, 0xc9, 0xc8, 0x4f, + 0x61, 0x60, 0x4e, 0x4f, + 0x2d, 0x61, 0x60, 0x2e, + 0x2d, 0xca, 0x61, 0x10, + 0xcb, 0x28, 0x29, 0x29, + 0xb0, 0xd2, 0xd7, 0x2f, + 0x2f, 0x2f, 0xd7, 0x4b, + 0xcf, 0xcf, 0x4f, 0xcf, + 0x49, 0xd5, 0x4b, 0xce, + 0xcf, 0xd5, 0x67, 0x60, + 0x2f, 0x4b, 0x2d, 0x2a, + 0xce, 0xcc, 0xcf, 0x63, + 0xe0, 0x00, 0x29, 0xd0, + 0x37, 0xd4, 0x33, 0x04, + 0x00, 0x00, 0x00, 0xff, + 0xff, +} + +func TestReadHeader(t *testing.T) { + r := NewHeaderReader() + h, err := r.Decode(headerDataTest) + if err != nil { + t.Fatalf("Error: %v", err) + return + } + if len(h) != 3 { + t.Errorf("Header count = %d (expected 3)", len(h)) + } + if h.Get("Url") != "http://www.google.com/" { + t.Errorf("Url: %q != %q", h.Get("Url"), "http://www.google.com/") + } + if h.Get("Method") != "get" { + t.Errorf("Method: %q != %q", h.Get("Method"), "get") + } + if h.Get("Version") != "http/1.1" { + t.Errorf("Version: %q != %q", h.Get("Version"), "http/1.1") + } +} + +func TestWriteHeader(t *testing.T) { + for level := zlib.NoCompression; level <= zlib.BestCompression; level++ { + r := NewHeaderReader() + w := NewHeaderWriter(level) + for i := 0; i < 100; i++ { + b := new(bytes.Buffer) + gold := http.Header{ + "Url": []string{"http://www.google.com/"}, + "Method": []string{"get"}, + "Version": []string{"http/1.1"}, + } + w.WriteHeader(b, gold) + h, err := r.Decode(b.Bytes()) + if err != nil { + t.Errorf("(level=%d i=%d) Error: %v", level, i, err) + return + } + if len(h) != len(gold) { + t.Errorf("(level=%d i=%d) Header count = %d (expected %d)", level, i, len(h), len(gold)) + } + for k, _ := range h { + if h.Get(k) != gold.Get(k) { + t.Errorf("(level=%d i=%d) %s: %q != %q", level, i, k, h.Get(k), gold.Get(k)) + } + } + } + } +} diff --git a/libgo/go/http/transfer.go b/libgo/go/http/transfer.go index 41614f144fe..0fa8bed43aa 100644 --- a/libgo/go/http/transfer.go +++ b/libgo/go/http/transfer.go @@ -7,6 +7,7 @@ package http import ( "bufio" "io" + "io/ioutil" "os" "strconv" "strings" @@ -438,26 +439,39 @@ type body struct { hdr interface{} // non-nil (Response or Request) value means read trailer r *bufio.Reader // underlying wire-format reader for the trailer closing bool // is the connection to be closed after reading body? + closed bool +} + +// ErrBodyReadAfterClose is returned when reading a Request Body after +// the body has been closed. This typically happens when the body is +// read after an HTTP Handler calls WriteHeader or Write on its +// ResponseWriter. +var ErrBodyReadAfterClose = os.NewError("http: invalid Read on closed request Body") + +func (b *body) Read(p []byte) (n int, err os.Error) { + if b.closed { + return 0, ErrBodyReadAfterClose + } + return b.Reader.Read(p) } func (b *body) Close() os.Error { + if b.closed { + return nil + } + defer func() { + b.closed = true + }() if b.hdr == nil && b.closing { // no trailer and closing the connection next. // no point in reading to EOF. return nil } - trashBuf := make([]byte, 1024) // local for thread safety - for { - _, err := b.Read(trashBuf) - if err == nil { - continue - } - if err == os.EOF { - break - } + if _, err := io.Copy(ioutil.Discard, b); err != nil { return err } + if b.hdr == nil { // not reading trailer return nil } diff --git a/libgo/go/http/transport.go b/libgo/go/http/transport.go index 797d134aa85..73a2c2191ea 100644 --- a/libgo/go/http/transport.go +++ b/libgo/go/http/transport.go @@ -6,6 +6,8 @@ package http import ( "bufio" + "bytes" + "compress/gzip" "crypto/tls" "encoding/base64" "fmt" @@ -39,8 +41,9 @@ type Transport struct { // TODO: tunable on timeout on cached connections // TODO: optional pipelining - IgnoreEnvironment bool // don't look at environment variables for proxy configuration - DisableKeepAlives bool + IgnoreEnvironment bool // don't look at environment variables for proxy configuration + DisableKeepAlives bool + DisableCompression bool // MaxIdleConnsPerHost, if non-zero, controls the maximum idle // (keep-alive) to keep to keep per-host. If zero, @@ -215,6 +218,9 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { conn, err := net.Dial("tcp", cm.addr()) if err != nil { + if cm.proxyURL != nil { + err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err) + } return nil, err } @@ -286,10 +292,28 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { // useProxy returns true if requests to addr should use a proxy, // according to the NO_PROXY or no_proxy environment variable. +// addr is always a canonicalAddr with a host and port. func (t *Transport) useProxy(addr string) bool { if len(addr) == 0 { return true } + host, _, err := net.SplitHostPort(addr) + if err != nil { + return false + } + if host == "localhost" { + return false + } + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil && ip4[0] == 127 { + // 127.0.0.0/8 loopback isn't proxied. + return false + } + if bytes.Equal(ip, net.IPv6loopback) { + return false + } + } + no_proxy := t.getenvEitherCase("NO_PROXY") if no_proxy == "*" { return false @@ -474,6 +498,19 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { pc.mutateRequestFunc(req) } + // Ask for a compressed version if the caller didn't set their + // own value for Accept-Encoding. We only attempted to + // uncompress the gzip stream if we were the layer that + // requested it. + requestedGzip := false + if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" { + // Request gzip only, not deflate. Deflate is ambiguous and + // as universally supported anyway. + // See: http://www.gzip.org/zlib/zlib_faq.html#faq38 + requestedGzip = true + req.Header.Set("Accept-Encoding", "gzip") + } + pc.lk.Lock() pc.numExpectedResponses++ pc.lk.Unlock() @@ -490,6 +527,20 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { pc.lk.Lock() pc.numExpectedResponses-- pc.lk.Unlock() + + if re.err == nil && requestedGzip && re.res.Header.Get("Content-Encoding") == "gzip" { + re.res.Header.Del("Content-Encoding") + re.res.Header.Del("Content-Length") + re.res.ContentLength = -1 + esb := re.res.Body.(*bodyEOFSignal) + gzReader, err := gzip.NewReader(esb.body) + if err != nil { + pc.close() + return nil, err + } + esb.body = &readFirstCloseBoth{gzReader, esb.body} + } + return re.res, re.err } @@ -526,7 +577,7 @@ func responseIsKeepAlive(res *Response) bool { func readResponseWithEOFSignal(r *bufio.Reader, requestMethod string) (resp *Response, err os.Error) { resp, err = ReadResponse(r, requestMethod) if err == nil && resp.ContentLength != 0 { - resp.Body = &bodyEOFSignal{resp.Body, nil} + resp.Body = &bodyEOFSignal{body: resp.Body} } return } @@ -535,12 +586,16 @@ func readResponseWithEOFSignal(r *bufio.Reader, requestMethod string) (resp *Res // once, right before the final Read() or Close() call returns, but after // EOF has been seen. type bodyEOFSignal struct { - body io.ReadCloser - fn func() + body io.ReadCloser + fn func() + isClosed bool } func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) { n, err = es.body.Read(p) + if es.isClosed && n > 0 { + panic("http: unexpected bodyEOFSignal Read after Close; see issue 1725") + } if err == os.EOF && es.fn != nil { es.fn() es.fn = nil @@ -549,6 +604,7 @@ func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) { } func (es *bodyEOFSignal) Close() (err os.Error) { + es.isClosed = true err = es.body.Close() if err == nil && es.fn != nil { es.fn() @@ -556,3 +612,19 @@ func (es *bodyEOFSignal) Close() (err os.Error) { } return } + +type readFirstCloseBoth struct { + io.ReadCloser + io.Closer +} + +func (r *readFirstCloseBoth) Close() os.Error { + if err := r.ReadCloser.Close(); err != nil { + r.Closer.Close() + return err + } + if err := r.Closer.Close(); err != nil { + return err + } + return nil +} diff --git a/libgo/go/http/transport_test.go b/libgo/go/http/transport_test.go index e46f830c828..7610856738d 100644 --- a/libgo/go/http/transport_test.go +++ b/libgo/go/http/transport_test.go @@ -7,11 +7,16 @@ package http_test import ( + "bytes" + "compress/gzip" + "crypto/rand" "fmt" . "http" "http/httptest" + "io" "io/ioutil" "os" + "strconv" "testing" "time" ) @@ -24,7 +29,7 @@ var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) { if r.FormValue("close") == "true" { w.Header().Set("Connection", "close") } - fmt.Fprintf(w, "%s", r.RemoteAddr) + w.Write([]byte(r.RemoteAddr)) }) // Two subsequent requests and verify their response is the same. @@ -177,35 +182,47 @@ func TestTransportIdleCacheKeys(t *testing.T) { } func TestTransportMaxPerHostIdleConns(t *testing.T) { - ch := make(chan string) + resch := make(chan string) + gotReq := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - fmt.Fprintf(w, "%s", <-ch) + gotReq <- true + msg := <-resch + _, err := w.Write([]byte(msg)) + if err != nil { + t.Fatalf("Write: %v", err) + } })) defer ts.Close() maxIdleConns := 2 tr := &Transport{DisableKeepAlives: false, MaxIdleConnsPerHost: maxIdleConns} c := &Client{Transport: tr} - // Start 3 outstanding requests (will hang until we write to - // ch) + // Start 3 outstanding requests and wait for the server to get them. + // Their responses will hang until we we write to resch, though. donech := make(chan bool) doReq := func() { resp, _, err := c.Get(ts.URL) if err != nil { t.Error(err) } - ioutil.ReadAll(resp.Body) + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } donech <- true } go doReq() + <-gotReq go doReq() + <-gotReq go doReq() + <-gotReq if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g) } - ch <- "res1" + resch <- "res1" <-donech keys := tr.IdleConnKeysForTesting() if e, g := 1, len(keys); e != g { @@ -219,13 +236,13 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { t.Errorf("after first response, expected %d idle conns; got %d", e, g) } - ch <- "res2" + resch <- "res2" <-donech if e, g := 2, tr.IdleConnCountForTesting(cacheKey); e != g { t.Errorf("after second response, expected %d idle conns; got %d", e, g) } - ch <- "res3" + resch <- "res3" <-donech if e, g := maxIdleConns, tr.IdleConnCountForTesting(cacheKey); e != g { t.Errorf("after third response, still expected %d idle conns; got %d", e, g) @@ -239,26 +256,44 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { tr := &Transport{} c := &Client{Transport: tr} - fetch := func(n int) string { - res, _, err := c.Get(ts.URL) - if err != nil { - t.Fatalf("error in req #%d, GET: %v", n, err) + fetch := func(n, retries int) string { + condFatalf := func(format string, arg ...interface{}) { + if retries <= 0 { + t.Fatalf(format, arg...) + } + t.Logf("retrying shortly after expected error: "+format, arg...) + time.Sleep(1e9 / int64(retries)) } - body, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Fatalf("error in req #%d, ReadAll: %v", n, err) + for retries >= 0 { + retries-- + res, _, err := c.Get(ts.URL) + if err != nil { + condFatalf("error in req #%d, GET: %v", n, err) + continue + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + condFatalf("error in req #%d, ReadAll: %v", n, err) + continue + } + res.Body.Close() + return string(body) } - res.Body.Close() - return string(body) + panic("unreachable") } - body1 := fetch(1) - body2 := fetch(2) + body1 := fetch(1, 0) + body2 := fetch(2, 0) ts.CloseClientConnections() // surprise! - time.Sleep(25e6) // idle for a bit (test is inherently racey, but expectedly) - body3 := fetch(3) + // This test has an expected race. Sleeping for 25 ms prevents + // it on most fast machines, causing the next fetch() call to + // succeed quickly. But if we do get errors, fetch() will retry 5 + // times with some delays between. + time.Sleep(25e6) + + body3 := fetch(3, 5) if body1 != body2 { t.Errorf("expected body1 and body2 to be equal") @@ -288,10 +323,10 @@ func TestTransportHeadResponses(t *testing.T) { t.Errorf("error on loop %d: %v", i, err) } if e, g := "123", res.Header.Get("Content-Length"); e != g { - t.Errorf("loop %d: expected Content-Length header of %q, got %q", e, g) + t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) } if e, g := int64(0), res.ContentLength; e != g { - t.Errorf("loop %d: expected res.ContentLength of %v, got %v", e, g) + t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) } } } @@ -338,6 +373,7 @@ func TestTransportNilURL(t *testing.T) { req.Proto = "HTTP/1.1" req.ProtoMajor = 1 req.ProtoMinor = 1 + req.Header = make(Header) tr := &Transport{} res, err := tr.RoundTrip(req) @@ -349,3 +385,147 @@ func TestTransportNilURL(t *testing.T) { t.Fatalf("Expected response body of %q; got %q", e, g) } } + +func TestTransportGzip(t *testing.T) { + const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + const nRandBytes = 1024 * 1024 + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e { + t.Errorf("Accept-Encoding = %q, want %q", g, e) + } + rw.Header().Set("Content-Encoding", "gzip") + + var w io.Writer = rw + var buf bytes.Buffer + if req.FormValue("chunked") == "0" { + w = &buf + defer io.Copy(rw, &buf) + defer func() { + rw.Header().Set("Content-Length", strconv.Itoa(buf.Len())) + }() + } + gz, _ := gzip.NewWriter(w) + gz.Write([]byte(testString)) + if req.FormValue("body") == "large" { + io.Copyn(gz, rand.Reader, nRandBytes) + } + gz.Close() + })) + defer ts.Close() + + for _, chunked := range []string{"1", "0"} { + c := &Client{Transport: &Transport{}} + + // First fetch something large, but only read some of it. + res, _, err := c.Get(ts.URL + "?body=large&chunked=" + chunked) + if err != nil { + t.Fatalf("large get: %v", err) + } + buf := make([]byte, len(testString)) + n, err := io.ReadFull(res.Body, buf) + if err != nil { + t.Fatalf("partial read of large response: size=%d, %v", n, err) + } + if e, g := testString, string(buf); e != g { + t.Errorf("partial read got %q, expected %q", g, e) + } + res.Body.Close() + // Read on the body, even though it's closed + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected error post-closed large Read; got = %d, %v", n, err) + } + + // Then something small. + res, _, err = c.Get(ts.URL + "?chunked=" + chunked) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if g, e := string(body), testString; g != e { + t.Fatalf("body = %q; want %q", g, e) + } + if g, e := res.Header.Get("Content-Encoding"), ""; g != e { + t.Fatalf("Content-Encoding = %q; want %q", g, e) + } + + // Read on the body after it's been fully read: + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err) + } + res.Body.Close() + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected Read error after Close; got %d, %v", n, err) + } + } +} + +// TestTransportGzipRecursive sends a gzip quine and checks that the +// client gets the same value back. This is more cute than anything, +// but checks that we don't recurse forever, and checks that +// Content-Encoding is removed. +func TestTransportGzipRecursive(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Encoding", "gzip") + w.Write(rgz) + })) + defer ts.Close() + + c := &Client{Transport: &Transport{}} + res, _, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(body, rgz) { + t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x", + body, rgz) + } + if g, e := res.Header.Get("Content-Encoding"), ""; g != e { + t.Fatalf("Content-Encoding = %q; want %q", g, e) + } +} + +// rgz is a gzip quine that uncompresses to itself. +var rgz = []byte{ + 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73, + 0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0, + 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, + 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, + 0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60, + 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2, + 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00, + 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, + 0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16, + 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05, + 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff, + 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00, + 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, + 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, + 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, + 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, + 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, + 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, + 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, + 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, + 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff, + 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, + 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, + 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, + 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, + 0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06, + 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00, + 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, + 0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, + 0x00, 0x00, +} diff --git a/libgo/go/http/url.go b/libgo/go/http/url.go index 0fc0cb2d76e..d7ee14ee84a 100644 --- a/libgo/go/http/url.go +++ b/libgo/go/http/url.go @@ -449,7 +449,7 @@ func ParseURLReference(rawurlref string) (url *URL, err os.Error) { // // There are redundant fields stored in the URL structure: // the String method consults Scheme, Path, Host, RawUserinfo, -// RawQuery, and Fragment, but not Raw, RawPath or Authority. +// RawQuery, and Fragment, but not Raw, RawPath or RawAuthority. func (url *URL) String() string { result := "" if url.Scheme != "" { diff --git a/libgo/go/image/decode_test.go b/libgo/go/image/decode_test.go index 0716ad9055b..fee537cf1a2 100644 --- a/libgo/go/image/decode_test.go +++ b/libgo/go/image/decode_test.go @@ -10,9 +10,11 @@ import ( "os" "testing" - // TODO(nigeltao): implement bmp, gif and tiff decoders. + // TODO(nigeltao): implement bmp decoder. + _ "image/gif" _ "image/jpeg" _ "image/png" + _ "image/tiff" ) const goldenFile = "testdata/video-001.png" @@ -26,11 +28,11 @@ var imageTests = []imageTest{ //{"testdata/video-001.bmp", 0}, // GIF images are restricted to a 256-color palette and the conversion // to GIF loses significant image quality. - //{"testdata/video-001.gif", 64<<8}, + {"testdata/video-001.gif", 64 << 8}, // JPEG is a lossy format and hence needs a non-zero tolerance. {"testdata/video-001.jpeg", 8 << 8}, {"testdata/video-001.png", 0}, - //{"testdata/video-001.tiff", 0}, + {"testdata/video-001.tiff", 0}, } func decode(filename string) (image.Image, string, os.Error) { diff --git a/libgo/go/image/format.go b/libgo/go/image/format.go index 1d541b09406..b4859325e1d 100644 --- a/libgo/go/image/format.go +++ b/libgo/go/image/format.go @@ -25,7 +25,8 @@ var formats []format // RegisterFormat registers an image format for use by Decode. // Name is the name of the format, like "jpeg" or "png". -// Magic is the magic prefix that identifies the format's encoding. +// Magic is the magic prefix that identifies the format's encoding. The magic +// string can contain "?" wildcards that each match any one byte. // Decode is the function that decodes the encoded image. // DecodeConfig is the function that decodes just its configuration. func RegisterFormat(name, magic string, decode func(io.Reader) (Image, os.Error), decodeConfig func(io.Reader) (Config, os.Error)) { @@ -46,11 +47,24 @@ func asReader(r io.Reader) reader { return bufio.NewReader(r) } -// sniff determines the format of r's data. +// Match returns whether magic matches b. Magic may contain "?" wildcards. +func match(magic string, b []byte) bool { + if len(magic) != len(b) { + return false + } + for i, c := range b { + if magic[i] != c && magic[i] != '?' { + return false + } + } + return true +} + +// Sniff determines the format of r's data. func sniff(r reader) format { for _, f := range formats { - s, err := r.Peek(len(f.magic)) - if err == nil && string(s) == f.magic { + b, err := r.Peek(len(f.magic)) + if err == nil && match(f.magic, b) { return f } } diff --git a/libgo/go/image/gif/reader.go b/libgo/go/image/gif/reader.go new file mode 100644 index 00000000000..d37f52689ee --- /dev/null +++ b/libgo/go/image/gif/reader.go @@ -0,0 +1,392 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package gif implements a GIF image decoder. +// +// The GIF specification is at http://www.w3.org/Graphics/GIF/spec-gif89a.txt. +package gif + +import ( + "bufio" + "compress/lzw" + "fmt" + "image" + "io" + "os" +) + +// If the io.Reader does not also have ReadByte, then decode will introduce its own buffering. +type reader interface { + io.Reader + io.ByteReader +} + +// Masks etc. +const ( + // Fields. + fColorMapFollows = 1 << 7 + + // Image fields. + ifInterlace = 1 << 6 + + // Graphic control flags. + gcTransparentColorSet = 1 << 0 +) + +// Section indicators. +const ( + sExtension = 0x21 + sImageDescriptor = 0x2C + sTrailer = 0x3B +) + +// Extensions. +const ( + eText = 0x01 // Plain Text + eGraphicControl = 0xF9 // Graphic Control + eComment = 0xFE // Comment + eApplication = 0xFF // Application +) + +// decoder is the type used to decode a GIF file. +type decoder struct { + r reader + + // From header. + vers string + width int + height int + flags byte + headerFields byte + backgroundIndex byte + loopCount int + delayTime int + + // Unused from header. + aspect byte + + // From image descriptor. + imageFields byte + + // From graphics control. + transparentIndex byte + + // Computed. + pixelSize uint + globalColorMap image.PalettedColorModel + + // Used when decoding. + delay []int + image []*image.Paletted + tmp [1024]byte // must be at least 768 so we can read color map +} + +// blockReader parses the block structure of GIF image data, which +// comprises (n, (n bytes)) blocks, with 1 <= n <= 255. It is the +// reader given to the LZW decoder, which is thus immune to the +// blocking. After the LZW decoder completes, there will be a 0-byte +// block remaining (0, ()), but under normal execution blockReader +// doesn't consume it, so it is handled in decode. +type blockReader struct { + r reader + slice []byte + tmp [256]byte +} + +func (b *blockReader) Read(p []byte) (n int, err os.Error) { + if len(p) == 0 { + return + } + if len(b.slice) > 0 { + n = copy(p, b.slice) + b.slice = b.slice[n:] + return + } + var blockLen uint8 + blockLen, err = b.r.ReadByte() + if err != nil { + return + } + if blockLen == 0 { + return 0, os.EOF + } + b.slice = b.tmp[0:blockLen] + if _, err = io.ReadFull(b.r, b.slice); err != nil { + return + } + return b.Read(p) +} + +// decode reads a GIF image from r and stores the result in d. +func (d *decoder) decode(r io.Reader, configOnly bool) os.Error { + // Add buffering if r does not provide ReadByte. + if rr, ok := r.(reader); ok { + d.r = rr + } else { + d.r = bufio.NewReader(r) + } + + err := d.readHeaderAndScreenDescriptor() + if err != nil { + return err + } + if configOnly { + return nil + } + + if d.headerFields&fColorMapFollows != 0 { + if d.globalColorMap, err = d.readColorMap(); err != nil { + return err + } + } + + d.image = nil + +Loop: + for err == nil { + var c byte + c, err = d.r.ReadByte() + if err == os.EOF { + break + } + switch c { + case sExtension: + err = d.readExtension() + + case sImageDescriptor: + var m *image.Paletted + m, err = d.newImageFromDescriptor() + if err != nil { + break + } + if d.imageFields&fColorMapFollows != 0 { + m.Palette, err = d.readColorMap() + if err != nil { + break + } + // TODO: do we set transparency in this map too? That would be + // d.setTransparency(m.Palette) + } else { + m.Palette = d.globalColorMap + } + var litWidth uint8 + litWidth, err = d.r.ReadByte() + if err != nil { + return err + } + if litWidth > 8 { + return fmt.Errorf("gif: pixel size in decode out of range: %d", litWidth) + } + // A wonderfully Go-like piece of magic. Unfortunately it's only at its + // best for 8-bit pixels. + lzwr := lzw.NewReader(&blockReader{r: d.r}, lzw.LSB, int(litWidth)) + if _, err = io.ReadFull(lzwr, m.Pix); err != nil { + break + } + + // There should be a "0" block remaining; drain that. + c, err = d.r.ReadByte() + if err != nil { + return err + } + if c != 0 { + return os.ErrorString("gif: extra data after image") + } + d.image = append(d.image, m) + d.delay = append(d.delay, d.delayTime) + d.delayTime = 0 // TODO: is this correct, or should we hold on to the value? + + case sTrailer: + break Loop + + default: + err = fmt.Errorf("gif: unknown block type: 0x%.2x", c) + } + } + if err != nil { + return err + } + if len(d.image) == 0 { + return io.ErrUnexpectedEOF + } + return nil +} + +func (d *decoder) readHeaderAndScreenDescriptor() os.Error { + _, err := io.ReadFull(d.r, d.tmp[0:13]) + if err != nil { + return err + } + d.vers = string(d.tmp[0:6]) + if d.vers != "GIF87a" && d.vers != "GIF89a" { + return fmt.Errorf("gif: can't recognize format %s", d.vers) + } + d.width = int(d.tmp[6]) + int(d.tmp[7])<<8 + d.height = int(d.tmp[8]) + int(d.tmp[9])<<8 + d.headerFields = d.tmp[10] + d.backgroundIndex = d.tmp[11] + d.aspect = d.tmp[12] + d.loopCount = -1 + d.pixelSize = uint(d.headerFields&7) + 1 + return nil +} + +func (d *decoder) readColorMap() (image.PalettedColorModel, os.Error) { + if d.pixelSize > 8 { + return nil, fmt.Errorf("gif: can't handle %d bits per pixel", d.pixelSize) + } + numColors := 1 << d.pixelSize + numValues := 3 * numColors + _, err := io.ReadFull(d.r, d.tmp[0:numValues]) + if err != nil { + return nil, fmt.Errorf("gif: short read on color map: %s", err) + } + colorMap := make(image.PalettedColorModel, numColors) + j := 0 + for i := range colorMap { + colorMap[i] = image.RGBAColor{d.tmp[j+0], d.tmp[j+1], d.tmp[j+2], 0xFF} + j += 3 + } + return colorMap, nil +} + +func (d *decoder) readExtension() os.Error { + extension, err := d.r.ReadByte() + if err != nil { + return err + } + size := 0 + switch extension { + case eText: + size = 13 + case eGraphicControl: + return d.readGraphicControl() + case eComment: + // nothing to do but read the data. + case eApplication: + b, err := d.r.ReadByte() + if err != nil { + return err + } + // The spec requires size be 11, but Adobe sometimes uses 10. + size = int(b) + default: + return fmt.Errorf("gif: unknown extension 0x%.2x", extension) + } + if size > 0 { + if _, err := d.r.Read(d.tmp[0:size]); err != nil { + return err + } + } + + // Application Extension with "NETSCAPE2.0" as string and 1 in data means + // this extension defines a loop count. + if extension == eApplication && string(d.tmp[:size]) == "NETSCAPE2.0" { + n, err := d.readBlock() + if n == 0 || err != nil { + return err + } + if n == 3 && d.tmp[0] == 1 { + d.loopCount = int(d.tmp[1]) | int(d.tmp[2])<<8 + } + } + for { + n, err := d.readBlock() + if n == 0 || err != nil { + return err + } + } + panic("unreachable") +} + +func (d *decoder) readGraphicControl() os.Error { + if _, err := io.ReadFull(d.r, d.tmp[0:6]); err != nil { + return fmt.Errorf("gif: can't read graphic control: %s", err) + } + d.flags = d.tmp[1] + d.delayTime = int(d.tmp[2]) | int(d.tmp[3])<<8 + if d.flags&gcTransparentColorSet != 0 { + d.transparentIndex = d.tmp[4] + d.setTransparency(d.globalColorMap) + } + return nil +} + +func (d *decoder) setTransparency(colorMap image.PalettedColorModel) { + if int(d.transparentIndex) < len(colorMap) { + colorMap[d.transparentIndex] = image.RGBAColor{} + } +} + +func (d *decoder) newImageFromDescriptor() (*image.Paletted, os.Error) { + if _, err := io.ReadFull(d.r, d.tmp[0:9]); err != nil { + return nil, fmt.Errorf("gif: can't read image descriptor: %s", err) + } + _ = int(d.tmp[0]) + int(d.tmp[1])<<8 // TODO: honor left value + _ = int(d.tmp[2]) + int(d.tmp[3])<<8 // TODO: honor top value + width := int(d.tmp[4]) + int(d.tmp[5])<<8 + height := int(d.tmp[6]) + int(d.tmp[7])<<8 + d.imageFields = d.tmp[8] + if d.imageFields&ifInterlace != 0 { + return nil, os.ErrorString("gif: can't handle interlaced images") + } + return image.NewPaletted(width, height, nil), nil +} + +func (d *decoder) readBlock() (int, os.Error) { + n, err := d.r.ReadByte() + if n == 0 || err != nil { + return 0, err + } + return io.ReadFull(d.r, d.tmp[0:n]) +} + +// Decode reads a GIF image from r and returns the first embedded +// image as an image.Image. +// Limitation: The file must be 8 bits per pixel and have no interlacing. +func Decode(r io.Reader) (image.Image, os.Error) { + var d decoder + if err := d.decode(r, false); err != nil { + return nil, err + } + return d.image[0], nil +} + +// GIF represents the possibly multiple images stored in a GIF file. +type GIF struct { + Image []*image.Paletted // The successive images. + Delay []int // The successive delay times, one per frame, in 100ths of a second. + LoopCount int // The loop count. +} + +// DecodeAll reads a GIF image from r and returns the sequential frames +// and timing information. +// Limitation: The file must be 8 bits per pixel and have no interlacing. +func DecodeAll(r io.Reader) (*GIF, os.Error) { + var d decoder + if err := d.decode(r, false); err != nil { + return nil, err + } + gif := &GIF{ + Image: d.image, + LoopCount: d.loopCount, + Delay: d.delay, + } + return gif, nil +} + +// DecodeConfig returns the color model and dimensions of a GIF image without +// decoding the entire image. +func DecodeConfig(r io.Reader) (image.Config, os.Error) { + var d decoder + if err := d.decode(r, true); err != nil { + return image.Config{}, err + } + colorMap := d.globalColorMap + return image.Config{colorMap, d.width, d.height}, nil +} + +func init() { + image.RegisterFormat("gif", "GIF8?a", Decode, DecodeConfig) +} diff --git a/libgo/go/image/image.go b/libgo/go/image/image.go index c0e96e1f7b1..4350acc8203 100644 --- a/libgo/go/image/image.go +++ b/libgo/go/image/image.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The image package implements a basic 2-D image library. +// Package image implements a basic 2-D image library. package image // A Config consists of an image's color model and dimensions. @@ -51,6 +51,13 @@ func (p *RGBA) Set(x, y int, c Color) { p.Pix[y*p.Stride+x] = toRGBAColor(c).(RGBAColor) } +func (p *RGBA) SetRGBA(x, y int, c RGBAColor) { + if !p.Rect.Contains(Point{x, y}) { + return + } + p.Pix[y*p.Stride+x] = c +} + // Opaque scans the entire image and returns whether or not it is fully opaque. func (p *RGBA) Opaque() bool { if p.Rect.Empty() { @@ -103,6 +110,13 @@ func (p *RGBA64) Set(x, y int, c Color) { p.Pix[y*p.Stride+x] = toRGBA64Color(c).(RGBA64Color) } +func (p *RGBA64) SetRGBA64(x, y int, c RGBA64Color) { + if !p.Rect.Contains(Point{x, y}) { + return + } + p.Pix[y*p.Stride+x] = c +} + // Opaque scans the entire image and returns whether or not it is fully opaque. func (p *RGBA64) Opaque() bool { if p.Rect.Empty() { @@ -155,6 +169,13 @@ func (p *NRGBA) Set(x, y int, c Color) { p.Pix[y*p.Stride+x] = toNRGBAColor(c).(NRGBAColor) } +func (p *NRGBA) SetNRGBA(x, y int, c NRGBAColor) { + if !p.Rect.Contains(Point{x, y}) { + return + } + p.Pix[y*p.Stride+x] = c +} + // Opaque scans the entire image and returns whether or not it is fully opaque. func (p *NRGBA) Opaque() bool { if p.Rect.Empty() { @@ -207,6 +228,13 @@ func (p *NRGBA64) Set(x, y int, c Color) { p.Pix[y*p.Stride+x] = toNRGBA64Color(c).(NRGBA64Color) } +func (p *NRGBA64) SetNRGBA64(x, y int, c NRGBA64Color) { + if !p.Rect.Contains(Point{x, y}) { + return + } + p.Pix[y*p.Stride+x] = c +} + // Opaque scans the entire image and returns whether or not it is fully opaque. func (p *NRGBA64) Opaque() bool { if p.Rect.Empty() { @@ -259,6 +287,13 @@ func (p *Alpha) Set(x, y int, c Color) { p.Pix[y*p.Stride+x] = toAlphaColor(c).(AlphaColor) } +func (p *Alpha) SetAlpha(x, y int, c AlphaColor) { + if !p.Rect.Contains(Point{x, y}) { + return + } + p.Pix[y*p.Stride+x] = c +} + // Opaque scans the entire image and returns whether or not it is fully opaque. func (p *Alpha) Opaque() bool { if p.Rect.Empty() { @@ -311,6 +346,13 @@ func (p *Alpha16) Set(x, y int, c Color) { p.Pix[y*p.Stride+x] = toAlpha16Color(c).(Alpha16Color) } +func (p *Alpha16) SetAlpha16(x, y int, c Alpha16Color) { + if !p.Rect.Contains(Point{x, y}) { + return + } + p.Pix[y*p.Stride+x] = c +} + // Opaque scans the entire image and returns whether or not it is fully opaque. func (p *Alpha16) Opaque() bool { if p.Rect.Empty() { @@ -363,6 +405,13 @@ func (p *Gray) Set(x, y int, c Color) { p.Pix[y*p.Stride+x] = toGrayColor(c).(GrayColor) } +func (p *Gray) SetGray(x, y int, c GrayColor) { + if !p.Rect.Contains(Point{x, y}) { + return + } + p.Pix[y*p.Stride+x] = c +} + // Opaque scans the entire image and returns whether or not it is fully opaque. func (p *Gray) Opaque() bool { return true @@ -401,6 +450,13 @@ func (p *Gray16) Set(x, y int, c Color) { p.Pix[y*p.Stride+x] = toGray16Color(c).(Gray16Color) } +func (p *Gray16) SetGray16(x, y int, c Gray16Color) { + if !p.Rect.Contains(Point{x, y}) { + return + } + p.Pix[y*p.Stride+x] = c +} + // Opaque scans the entire image and returns whether or not it is fully opaque. func (p *Gray16) Opaque() bool { return true diff --git a/libgo/go/image/jpeg/fdct.go b/libgo/go/image/jpeg/fdct.go new file mode 100644 index 00000000000..3f8be4e3260 --- /dev/null +++ b/libgo/go/image/jpeg/fdct.go @@ -0,0 +1,190 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jpeg + +// This file implements a Forward Discrete Cosine Transformation. + +/* +It is based on the code in jfdctint.c from the Independent JPEG Group, +found at http://www.ijg.org/files/jpegsrc.v8c.tar.gz. + +The "LEGAL ISSUES" section of the README in that archive says: + +In plain English: + +1. We don't promise that this software works. (But if you find any bugs, + please let us know!) +2. You can use this software for whatever you want. You don't have to pay us. +3. You may not pretend that you wrote this software. If you use it in a + program, you must acknowledge somewhere in your documentation that + you've used the IJG code. + +In legalese: + +The authors make NO WARRANTY or representation, either express or implied, +with respect to this software, its quality, accuracy, merchantability, or +fitness for a particular purpose. This software is provided "AS IS", and you, +its user, assume the entire risk as to its quality and accuracy. + +This software is copyright (C) 1991-2011, Thomas G. Lane, Guido Vollbeding. +All Rights Reserved except as specified below. + +Permission is hereby granted to use, copy, modify, and distribute this +software (or portions thereof) for any purpose, without fee, subject to these +conditions: +(1) If any part of the source code for this software is distributed, then this +README file must be included, with this copyright and no-warranty notice +unaltered; and any additions, deletions, or changes to the original files +must be clearly indicated in accompanying documentation. +(2) If only executable code is distributed, then the accompanying +documentation must state that "this software is based in part on the work of +the Independent JPEG Group". +(3) Permission for use of this software is granted only if the user accepts +full responsibility for any undesirable consequences; the authors accept +NO LIABILITY for damages of any kind. + +These conditions apply to any software derived from or based on the IJG code, +not just to the unmodified library. If you use our work, you ought to +acknowledge us. + +Permission is NOT granted for the use of any IJG author's name or company name +in advertising or publicity relating to this software or products derived from +it. This software may be referred to only as "the Independent JPEG Group's +software". + +We specifically permit and encourage the use of this software as the basis of +commercial products, provided that all warranty or liability claims are +assumed by the product vendor. +*/ + +// Trigonometric constants in 13-bit fixed point format. +const ( + fix_0_298631336 = 2446 + fix_0_390180644 = 3196 + fix_0_541196100 = 4433 + fix_0_765366865 = 6270 + fix_0_899976223 = 7373 + fix_1_175875602 = 9633 + fix_1_501321110 = 12299 + fix_1_847759065 = 15137 + fix_1_961570560 = 16069 + fix_2_053119869 = 16819 + fix_2_562915447 = 20995 + fix_3_072711026 = 25172 +) + +const ( + constBits = 13 + pass1Bits = 2 + centerJSample = 128 +) + +// fdct performs a forward DCT on an 8x8 block of coefficients, including a +// level shift. +func fdct(b *block) { + // Pass 1: process rows. + for y := 0; y < 8; y++ { + x0 := b[y*8+0] + x1 := b[y*8+1] + x2 := b[y*8+2] + x3 := b[y*8+3] + x4 := b[y*8+4] + x5 := b[y*8+5] + x6 := b[y*8+6] + x7 := b[y*8+7] + + tmp0 := x0 + x7 + tmp1 := x1 + x6 + tmp2 := x2 + x5 + tmp3 := x3 + x4 + + tmp10 := tmp0 + tmp3 + tmp12 := tmp0 - tmp3 + tmp11 := tmp1 + tmp2 + tmp13 := tmp1 - tmp2 + + tmp0 = x0 - x7 + tmp1 = x1 - x6 + tmp2 = x2 - x5 + tmp3 = x3 - x4 + + b[y*8+0] = (tmp10 + tmp11 - 8*centerJSample) << pass1Bits + b[y*8+4] = (tmp10 - tmp11) << pass1Bits + z1 := (tmp12 + tmp13) * fix_0_541196100 + z1 += 1 << (constBits - pass1Bits - 1) + b[y*8+2] = (z1 + tmp12*fix_0_765366865) >> (constBits - pass1Bits) + b[y*8+6] = (z1 - tmp13*fix_1_847759065) >> (constBits - pass1Bits) + + tmp10 = tmp0 + tmp3 + tmp11 = tmp1 + tmp2 + tmp12 = tmp0 + tmp2 + tmp13 = tmp1 + tmp3 + z1 = (tmp12 + tmp13) * fix_1_175875602 + z1 += 1 << (constBits - pass1Bits - 1) + tmp0 = tmp0 * fix_1_501321110 + tmp1 = tmp1 * fix_3_072711026 + tmp2 = tmp2 * fix_2_053119869 + tmp3 = tmp3 * fix_0_298631336 + tmp10 = tmp10 * -fix_0_899976223 + tmp11 = tmp11 * -fix_2_562915447 + tmp12 = tmp12 * -fix_0_390180644 + tmp13 = tmp13 * -fix_1_961570560 + + tmp12 += z1 + tmp13 += z1 + b[y*8+1] = (tmp0 + tmp10 + tmp12) >> (constBits - pass1Bits) + b[y*8+3] = (tmp1 + tmp11 + tmp13) >> (constBits - pass1Bits) + b[y*8+5] = (tmp2 + tmp11 + tmp12) >> (constBits - pass1Bits) + b[y*8+7] = (tmp3 + tmp10 + tmp13) >> (constBits - pass1Bits) + } + // Pass 2: process columns. + // We remove pass1Bits scaling, but leave results scaled up by an overall factor of 8. + for x := 0; x < 8; x++ { + tmp0 := b[0*8+x] + b[7*8+x] + tmp1 := b[1*8+x] + b[6*8+x] + tmp2 := b[2*8+x] + b[5*8+x] + tmp3 := b[3*8+x] + b[4*8+x] + + tmp10 := tmp0 + tmp3 + 1<<(pass1Bits-1) + tmp12 := tmp0 - tmp3 + tmp11 := tmp1 + tmp2 + tmp13 := tmp1 - tmp2 + + tmp0 = b[0*8+x] - b[7*8+x] + tmp1 = b[1*8+x] - b[6*8+x] + tmp2 = b[2*8+x] - b[5*8+x] + tmp3 = b[3*8+x] - b[4*8+x] + + b[0*8+x] = (tmp10 + tmp11) >> pass1Bits + b[4*8+x] = (tmp10 - tmp11) >> pass1Bits + + z1 := (tmp12 + tmp13) * fix_0_541196100 + z1 += 1 << (constBits + pass1Bits - 1) + b[2*8+x] = (z1 + tmp12*fix_0_765366865) >> (constBits + pass1Bits) + b[6*8+x] = (z1 - tmp13*fix_1_847759065) >> (constBits + pass1Bits) + + tmp10 = tmp0 + tmp3 + tmp11 = tmp1 + tmp2 + tmp12 = tmp0 + tmp2 + tmp13 = tmp1 + tmp3 + z1 = (tmp12 + tmp13) * fix_1_175875602 + z1 += 1 << (constBits + pass1Bits - 1) + tmp0 = tmp0 * fix_1_501321110 + tmp1 = tmp1 * fix_3_072711026 + tmp2 = tmp2 * fix_2_053119869 + tmp3 = tmp3 * fix_0_298631336 + tmp10 = tmp10 * -fix_0_899976223 + tmp11 = tmp11 * -fix_2_562915447 + tmp12 = tmp12 * -fix_0_390180644 + tmp13 = tmp13 * -fix_1_961570560 + + tmp12 += z1 + tmp13 += z1 + b[1*8+x] = (tmp0 + tmp10 + tmp12) >> (constBits + pass1Bits) + b[3*8+x] = (tmp1 + tmp11 + tmp13) >> (constBits + pass1Bits) + b[5*8+x] = (tmp2 + tmp11 + tmp12) >> (constBits + pass1Bits) + b[7*8+x] = (tmp3 + tmp10 + tmp13) >> (constBits + pass1Bits) + } +} diff --git a/libgo/go/image/jpeg/idct.go b/libgo/go/image/jpeg/idct.go index 5189931105b..e5a2f40f5db 100644 --- a/libgo/go/image/jpeg/idct.go +++ b/libgo/go/image/jpeg/idct.go @@ -63,7 +63,7 @@ const ( // // For more on the actual algorithm, see Z. Wang, "Fast algorithms for the discrete W transform and // for the discrete Fourier transform", IEEE Trans. on ASSP, Vol. ASSP- 32, pp. 803-816, Aug. 1984. -func idct(b *[blockSize]int) { +func idct(b *block) { // Horizontal 1-D IDCT. for y := 0; y < 8; y++ { // If all the AC components are zero, then the IDCT is trivial. diff --git a/libgo/go/image/jpeg/reader.go b/libgo/go/image/jpeg/reader.go index fb9cb11bb7f..21a6fff9698 100644 --- a/libgo/go/image/jpeg/reader.go +++ b/libgo/go/image/jpeg/reader.go @@ -2,18 +2,22 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The jpeg package implements a decoder for JPEG images, as defined in ITU-T T.81. +// Package jpeg implements a JPEG image decoder and encoder. +// +// JPEG is defined in ITU-T T.81: http://www.w3.org/Graphics/JPEG/itu-t81.pdf. package jpeg -// See http://www.w3.org/Graphics/JPEG/itu-t81.pdf - import ( "bufio" "image" + "image/ycbcr" "io" "os" ) +// TODO(nigeltao): fix up the doc comment style so that sentences start with +// the name of the type or function that they annotate. + // A FormatError reports that the input is not a valid JPEG. type FormatError string @@ -26,12 +30,14 @@ func (e UnsupportedError) String() string { return "unsupported JPEG feature: " // Component specification, specified in section B.2.2. type component struct { + h int // Horizontal sampling factor. + v int // Vertical sampling factor. c uint8 // Component identifier. - h uint8 // Horizontal sampling factor. - v uint8 // Vertical sampling factor. tq uint8 // Quantization table destination selector. } +type block [blockSize]int + const ( blockSize = 64 // A DCT block is 8x8. @@ -84,13 +90,13 @@ type Reader interface { type decoder struct { r Reader width, height int - image *image.RGBA + img *ycbcr.YCbCr ri int // Restart Interval. comps [nComponent]component huff [maxTc + 1][maxTh + 1]huffman - quant [maxTq + 1][blockSize]int + quant [maxTq + 1]block b bits - blocks [nComponent][maxH * maxV][blockSize]int + blocks [nComponent][maxH * maxV]block tmp [1024]byte } @@ -130,9 +136,9 @@ func (d *decoder) processSOF(n int) os.Error { } for i := 0; i < nComponent; i++ { hv := d.tmp[7+3*i] + d.comps[i].h = int(hv >> 4) + d.comps[i].v = int(hv & 0x0f) d.comps[i].c = d.tmp[6+3*i] - d.comps[i].h = hv >> 4 - d.comps[i].v = hv & 0x0f d.comps[i].tq = d.tmp[8+3*i] // We only support YCbCr images, and 4:4:4, 4:2:2 or 4:2:0 chroma downsampling ratios. This implies that // the (h, v) values for the Y component are either (1, 1), (2, 1) or (2, 2), and the @@ -176,71 +182,47 @@ func (d *decoder) processDQT(n int) os.Error { return nil } -// Set the Pixel (px, py)'s RGB value, based on its YCbCr value. -func (d *decoder) calcPixel(px, py, lumaBlock, lumaIndex, chromaIndex int) { - y, cb, cr := d.blocks[0][lumaBlock][lumaIndex], d.blocks[1][0][chromaIndex], d.blocks[2][0][chromaIndex] - // The JFIF specification (http://www.w3.org/Graphics/JPEG/jfif3.pdf, page 3) gives the formula - // for translating YCbCr to RGB as: - // R = Y + 1.402 (Cr-128) - // G = Y - 0.34414 (Cb-128) - 0.71414 (Cr-128) - // B = Y + 1.772 (Cb-128) - yPlusHalf := 100000*y + 50000 - cb -= 128 - cr -= 128 - r := (yPlusHalf + 140200*cr) / 100000 - g := (yPlusHalf - 34414*cb - 71414*cr) / 100000 - b := (yPlusHalf + 177200*cb) / 100000 - if r < 0 { - r = 0 - } else if r > 255 { - r = 255 +// Clip x to the range [0, 255] inclusive. +func clip(x int) uint8 { + if x < 0 { + return 0 } - if g < 0 { - g = 0 - } else if g > 255 { - g = 255 + if x > 255 { + return 255 } - if b < 0 { - b = 0 - } else if b > 255 { - b = 255 - } - d.image.Pix[py*d.image.Stride+px] = image.RGBAColor{uint8(r), uint8(g), uint8(b), 0xff} + return uint8(x) } -// Convert the MCU from YCbCr to RGB. -func (d *decoder) convertMCU(mx, my, h0, v0 int) { - lumaBlock := 0 +// Store the MCU to the image. +func (d *decoder) storeMCU(mx, my int) { + h0, v0 := d.comps[0].h, d.comps[0].v + // Store the luma blocks. for v := 0; v < v0; v++ { for h := 0; h < h0; h++ { - chromaBase := 8*4*v + 4*h - py := 8 * (v0*my + v) - for y := 0; y < 8 && py < d.height; y++ { - px := 8 * (h0*mx + h) - lumaIndex := 8 * y - chromaIndex := chromaBase + 8*(y/v0) - for x := 0; x < 8 && px < d.width; x++ { - d.calcPixel(px, py, lumaBlock, lumaIndex, chromaIndex) - if h0 == 1 { - chromaIndex += 1 - } else { - chromaIndex += x % 2 - } - lumaIndex++ - px++ + p := 8 * ((v0*my+v)*d.img.YStride + (h0*mx + h)) + for y := 0; y < 8; y++ { + for x := 0; x < 8; x++ { + d.img.Y[p] = clip(d.blocks[0][h0*v+h][8*y+x]) + p++ } - py++ + p += d.img.YStride - 8 } - lumaBlock++ } } + // Store the chroma blocks. + p := 8 * (my*d.img.CStride + mx) + for y := 0; y < 8; y++ { + for x := 0; x < 8; x++ { + d.img.Cb[p] = clip(d.blocks[1][0][8*y+x]) + d.img.Cr[p] = clip(d.blocks[2][0][8*y+x]) + p++ + } + p += d.img.CStride - 8 + } } // Specified in section B.2.3. func (d *decoder) processSOS(n int) os.Error { - if d.image == nil { - d.image = image.NewRGBA(d.width, d.height) - } if n != 4+2*nComponent { return UnsupportedError("SOS has wrong length") } @@ -255,7 +237,6 @@ func (d *decoder) processSOS(n int) os.Error { td uint8 // DC table selector. ta uint8 // AC table selector. } - h0, v0 := int(d.comps[0].h), int(d.comps[0].v) // The h and v values from the Y components. for i := 0; i < nComponent; i++ { cs := d.tmp[1+2*i] // Component selector. if cs != d.comps[i].c { @@ -265,17 +246,42 @@ func (d *decoder) processSOS(n int) os.Error { scanComps[i].ta = d.tmp[2+2*i] & 0x0f } // mxx and myy are the number of MCUs (Minimum Coded Units) in the image. - mxx := (d.width + 8*int(h0) - 1) / (8 * int(h0)) - myy := (d.height + 8*int(v0) - 1) / (8 * int(v0)) + h0, v0 := d.comps[0].h, d.comps[0].v // The h and v values from the Y components. + mxx := (d.width + 8*h0 - 1) / (8 * h0) + myy := (d.height + 8*v0 - 1) / (8 * v0) + if d.img == nil { + var subsampleRatio ycbcr.SubsampleRatio + n := h0 * v0 + switch n { + case 1: + subsampleRatio = ycbcr.SubsampleRatio444 + case 2: + subsampleRatio = ycbcr.SubsampleRatio422 + case 4: + subsampleRatio = ycbcr.SubsampleRatio420 + default: + panic("unreachable") + } + b := make([]byte, mxx*myy*(1*8*8*n+2*8*8)) + d.img = &ycbcr.YCbCr{ + Y: b[mxx*myy*(0*8*8*n+0*8*8) : mxx*myy*(1*8*8*n+0*8*8)], + Cb: b[mxx*myy*(1*8*8*n+0*8*8) : mxx*myy*(1*8*8*n+1*8*8)], + Cr: b[mxx*myy*(1*8*8*n+1*8*8) : mxx*myy*(1*8*8*n+2*8*8)], + SubsampleRatio: subsampleRatio, + YStride: mxx * 8 * h0, + CStride: mxx * 8, + Rect: image.Rect(0, 0, d.width, d.height), + } + } mcu, expectedRST := 0, uint8(rst0Marker) - var allZeroes [blockSize]int + var allZeroes block var dc [nComponent]int for my := 0; my < myy; my++ { for mx := 0; mx < mxx; mx++ { for i := 0; i < nComponent; i++ { qt := &d.quant[d.comps[i].tq] - for j := 0; j < int(d.comps[i].h*d.comps[i].v); j++ { + for j := 0; j < d.comps[i].h*d.comps[i].v; j++ { d.blocks[i][j] = allZeroes // Decode the DC coefficient, as specified in section F.2.2.1. @@ -299,20 +305,20 @@ func (d *decoder) processSOS(n int) os.Error { if err != nil { return err } - v0 := value >> 4 - v1 := value & 0x0f - if v1 != 0 { - k += int(v0) + val0 := value >> 4 + val1 := value & 0x0f + if val1 != 0 { + k += int(val0) if k > blockSize { return FormatError("bad DCT index") } - ac, err := d.receiveExtend(v1) + ac, err := d.receiveExtend(val1) if err != nil { return err } d.blocks[i][j][unzig[k]] = ac * qt[k] } else { - if v0 != 0x0f { + if val0 != 0x0f { break } k += 0x0f @@ -322,7 +328,7 @@ func (d *decoder) processSOS(n int) os.Error { idct(&d.blocks[i][j]) } // for j } // for i - d.convertMCU(mx, my, int(d.comps[0].h), int(d.comps[0].v)) + d.storeMCU(mx, my) mcu++ if d.ri > 0 && mcu%d.ri == 0 && mcu < mxx*myy { // A more sophisticated decoder could use RST[0-7] markers to resynchronize from corrupt input, @@ -431,7 +437,7 @@ func (d *decoder) decode(r io.Reader, configOnly bool) (image.Image, os.Error) { return nil, err } } - return d.image, nil + return d.img, nil } // Decode reads a JPEG image from r and returns it as an image.Image. diff --git a/libgo/go/image/jpeg/writer.go b/libgo/go/image/jpeg/writer.go new file mode 100644 index 00000000000..52b3dc4e2c1 --- /dev/null +++ b/libgo/go/image/jpeg/writer.go @@ -0,0 +1,553 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jpeg + +import ( + "bufio" + "image" + "image/ycbcr" + "io" + "os" +) + +// min returns the minimum of two integers. +func min(x, y int) int { + if x < y { + return x + } + return y +} + +// div returns a/b rounded to the nearest integer, instead of rounded to zero. +func div(a int, b int) int { + if a >= 0 { + return (a + (b >> 1)) / b + } + return -((-a + (b >> 1)) / b) +} + +// bitCount counts the number of bits needed to hold an integer. +var bitCount = [256]byte{ + 0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, +} + +type quantIndex int + +const ( + quantIndexLuminance quantIndex = iota + quantIndexChrominance + nQuantIndex +) + +// unscaledQuant are the unscaled quantization tables. Each encoder copies and +// scales the tables according to its quality parameter. +var unscaledQuant = [nQuantIndex][blockSize]byte{ + // Luminance. + { + 16, 11, 10, 16, 24, 40, 51, 61, + 12, 12, 14, 19, 26, 58, 60, 55, + 14, 13, 16, 24, 40, 57, 69, 56, + 14, 17, 22, 29, 51, 87, 80, 62, + 18, 22, 37, 56, 68, 109, 103, 77, + 24, 35, 55, 64, 81, 104, 113, 92, + 49, 64, 78, 87, 103, 121, 120, 101, + 72, 92, 95, 98, 112, 100, 103, 99, + }, + // Chrominance. + { + 17, 18, 24, 47, 99, 99, 99, 99, + 18, 21, 26, 66, 99, 99, 99, 99, + 24, 26, 56, 99, 99, 99, 99, 99, + 47, 66, 99, 99, 99, 99, 99, 99, + 99, 99, 99, 99, 99, 99, 99, 99, + 99, 99, 99, 99, 99, 99, 99, 99, + 99, 99, 99, 99, 99, 99, 99, 99, + 99, 99, 99, 99, 99, 99, 99, 99, + }, +} + +type huffIndex int + +const ( + huffIndexLuminanceDC huffIndex = iota + huffIndexLuminanceAC + huffIndexChrominanceDC + huffIndexChrominanceAC + nHuffIndex +) + +// huffmanSpec specifies a Huffman encoding. +type huffmanSpec struct { + // count[i] is the number of codes of length i bits. + count [16]byte + // value[i] is the decoded value of the i'th codeword. + value []byte +} + +// theHuffmanSpec is the Huffman encoding specifications. +// This encoder uses the same Huffman encoding for all images. +var theHuffmanSpec = [nHuffIndex]huffmanSpec{ + // Luminance DC. + { + [16]byte{0, 1, 5, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0}, + []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + }, + // Luminance AC. + { + [16]byte{0, 2, 1, 3, 3, 2, 4, 3, 5, 5, 4, 4, 0, 0, 1, 125}, + []byte{ + 0x01, 0x02, 0x03, 0x00, 0x04, 0x11, 0x05, 0x12, + 0x21, 0x31, 0x41, 0x06, 0x13, 0x51, 0x61, 0x07, + 0x22, 0x71, 0x14, 0x32, 0x81, 0x91, 0xa1, 0x08, + 0x23, 0x42, 0xb1, 0xc1, 0x15, 0x52, 0xd1, 0xf0, + 0x24, 0x33, 0x62, 0x72, 0x82, 0x09, 0x0a, 0x16, + 0x17, 0x18, 0x19, 0x1a, 0x25, 0x26, 0x27, 0x28, + 0x29, 0x2a, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, + 0x3a, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, + 0x4a, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, + 0x5a, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, + 0x6a, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, + 0x7a, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, + 0x8a, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, + 0x99, 0x9a, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, + 0xa8, 0xa9, 0xaa, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, + 0xb7, 0xb8, 0xb9, 0xba, 0xc2, 0xc3, 0xc4, 0xc5, + 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xd2, 0xd3, 0xd4, + 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xe1, 0xe2, + 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, + 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, + 0xf9, 0xfa, + }, + }, + // Chrominance DC. + { + [16]byte{0, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0}, + []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + }, + // Chrominance AC. + { + [16]byte{0, 2, 1, 2, 4, 4, 3, 4, 7, 5, 4, 4, 0, 1, 2, 119}, + []byte{ + 0x00, 0x01, 0x02, 0x03, 0x11, 0x04, 0x05, 0x21, + 0x31, 0x06, 0x12, 0x41, 0x51, 0x07, 0x61, 0x71, + 0x13, 0x22, 0x32, 0x81, 0x08, 0x14, 0x42, 0x91, + 0xa1, 0xb1, 0xc1, 0x09, 0x23, 0x33, 0x52, 0xf0, + 0x15, 0x62, 0x72, 0xd1, 0x0a, 0x16, 0x24, 0x34, + 0xe1, 0x25, 0xf1, 0x17, 0x18, 0x19, 0x1a, 0x26, + 0x27, 0x28, 0x29, 0x2a, 0x35, 0x36, 0x37, 0x38, + 0x39, 0x3a, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, + 0x49, 0x4a, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, + 0x59, 0x5a, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, + 0x69, 0x6a, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, + 0x79, 0x7a, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, + 0x88, 0x89, 0x8a, 0x92, 0x93, 0x94, 0x95, 0x96, + 0x97, 0x98, 0x99, 0x9a, 0xa2, 0xa3, 0xa4, 0xa5, + 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xb2, 0xb3, 0xb4, + 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xc2, 0xc3, + 0xc4, 0xc5, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xd2, + 0xd3, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, + 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, + 0xea, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, + 0xf9, 0xfa, + }, + }, +} + +// huffmanLUT is a compiled look-up table representation of a huffmanSpec. +// Each value maps to a uint32 of which the 8 most significant bits hold the +// codeword size in bits and the 24 least significant bits hold the codeword. +// The maximum codeword size is 16 bits. +type huffmanLUT []uint32 + +func (h *huffmanLUT) init(s huffmanSpec) { + maxValue := 0 + for _, v := range s.value { + if int(v) > maxValue { + maxValue = int(v) + } + } + *h = make([]uint32, maxValue+1) + code, k := uint32(0), 0 + for i := 0; i < len(s.count); i++ { + nBits := uint32(i+1) << 24 + for j := uint8(0); j < s.count[i]; j++ { + (*h)[s.value[k]] = nBits | code + code++ + k++ + } + code <<= 1 + } +} + +// theHuffmanLUT are compiled representations of theHuffmanSpec. +var theHuffmanLUT [4]huffmanLUT + +func init() { + for i, s := range theHuffmanSpec { + theHuffmanLUT[i].init(s) + } +} + +// writer is a buffered writer. +type writer interface { + Flush() os.Error + Write([]byte) (int, os.Error) + WriteByte(byte) os.Error +} + +// encoder encodes an image to the JPEG format. +type encoder struct { + // w is the writer to write to. err is the first error encountered during + // writing. All attempted writes after the first error become no-ops. + w writer + err os.Error + // buf is a scratch buffer. + buf [16]byte + // bits and nBits are accumulated bits to write to w. + bits uint32 + nBits uint8 + // quant is the scaled quantization tables. + quant [nQuantIndex][blockSize]byte +} + +func (e *encoder) flush() { + if e.err != nil { + return + } + e.err = e.w.Flush() +} + +func (e *encoder) write(p []byte) { + if e.err != nil { + return + } + _, e.err = e.w.Write(p) +} + +func (e *encoder) writeByte(b byte) { + if e.err != nil { + return + } + e.err = e.w.WriteByte(b) +} + +// emit emits the least significant nBits bits of bits to the bitstream. +// The precondition is bits < 1<= 8 { + b := uint8(bits >> 24) + e.writeByte(b) + if b == 0xff { + e.writeByte(0x00) + } + bits <<= 8 + nBits -= 8 + } + e.bits, e.nBits = bits, nBits +} + +// emitHuff emits the given value with the given Huffman encoder. +func (e *encoder) emitHuff(h huffIndex, value int) { + x := theHuffmanLUT[h][value] + e.emit(x&(1<<24-1), uint8(x>>24)) +} + +// emitHuffRLE emits a run of runLength copies of value encoded with the given +// Huffman encoder. +func (e *encoder) emitHuffRLE(h huffIndex, runLength, value int) { + a, b := value, value + if a < 0 { + a, b = -value, value-1 + } + var nBits uint8 + if a < 0x100 { + nBits = bitCount[a] + } else { + nBits = 8 + bitCount[a>>8] + } + e.emitHuff(h, runLength<<4|int(nBits)) + if nBits > 0 { + e.emit(uint32(b)&(1<> 8) + e.buf[3] = uint8(markerlen & 0xff) + e.write(e.buf[:4]) +} + +// writeDQT writes the Define Quantization Table marker. +func (e *encoder) writeDQT() { + markerlen := 2 + for _, q := range e.quant { + markerlen += 1 + len(q) + } + e.writeMarkerHeader(dqtMarker, markerlen) + for i, q := range e.quant { + e.writeByte(uint8(i)) + e.write(q[:]) + } +} + +// writeSOF0 writes the Start Of Frame (Baseline) marker. +func (e *encoder) writeSOF0(size image.Point) { + markerlen := 8 + 3*nComponent + e.writeMarkerHeader(sof0Marker, markerlen) + e.buf[0] = 8 // 8-bit color. + e.buf[1] = uint8(size.Y >> 8) + e.buf[2] = uint8(size.Y & 0xff) + e.buf[3] = uint8(size.X >> 8) + e.buf[4] = uint8(size.X & 0xff) + e.buf[5] = nComponent + for i := 0; i < nComponent; i++ { + e.buf[3*i+6] = uint8(i + 1) + // We use 4:2:0 chroma subsampling. + e.buf[3*i+7] = "\x22\x11\x11"[i] + e.buf[3*i+8] = "\x00\x01\x01"[i] + } + e.write(e.buf[:3*(nComponent-1)+9]) +} + +// writeDHT writes the Define Huffman Table marker. +func (e *encoder) writeDHT() { + markerlen := 2 + for _, s := range theHuffmanSpec { + markerlen += 1 + 16 + len(s.value) + } + e.writeMarkerHeader(dhtMarker, markerlen) + for i, s := range theHuffmanSpec { + e.writeByte("\x00\x10\x01\x11"[i]) + e.write(s.count[:]) + e.write(s.value) + } +} + +// writeBlock writes a block of pixel data using the given quantization table, +// returning the post-quantized DC value of the DCT-transformed block. +func (e *encoder) writeBlock(b *block, q quantIndex, prevDC int) int { + fdct(b) + // Emit the DC delta. + dc := div(b[0], (8 * int(e.quant[q][0]))) + e.emitHuffRLE(huffIndex(2*q+0), 0, dc-prevDC) + // Emit the AC components. + h, runLength := huffIndex(2*q+1), 0 + for k := 1; k < blockSize; k++ { + ac := div(b[unzig[k]], (8 * int(e.quant[q][k]))) + if ac == 0 { + runLength++ + } else { + for runLength > 15 { + e.emitHuff(h, 0xf0) + runLength -= 16 + } + e.emitHuffRLE(h, runLength, ac) + runLength = 0 + } + } + if runLength > 0 { + e.emitHuff(h, 0x00) + } + return dc +} + +// toYCbCr converts the 8x8 region of m whose top-left corner is p to its +// YCbCr values. +func toYCbCr(m image.Image, p image.Point, yBlock, cbBlock, crBlock *block) { + b := m.Bounds() + xmax := b.Max.X - 1 + ymax := b.Max.Y - 1 + for j := 0; j < 8; j++ { + for i := 0; i < 8; i++ { + r, g, b, _ := m.At(min(p.X+i, xmax), min(p.Y+j, ymax)).RGBA() + yy, cb, cr := ycbcr.RGBToYCbCr(uint8(r>>8), uint8(g>>8), uint8(b>>8)) + yBlock[8*j+i] = int(yy) + cbBlock[8*j+i] = int(cb) + crBlock[8*j+i] = int(cr) + } + } +} + +// rgbaToYCbCr is a specialized version of toYCbCr for image.RGBA images. +func rgbaToYCbCr(m *image.RGBA, p image.Point, yBlock, cbBlock, crBlock *block) { + b := m.Bounds() + xmax := b.Max.X - 1 + ymax := b.Max.Y - 1 + for j := 0; j < 8; j++ { + sj := p.Y + j + if sj > ymax { + sj = ymax + } + yoff := sj * m.Stride + for i := 0; i < 8; i++ { + sx := p.X + i + if sx > xmax { + sx = xmax + } + col := &m.Pix[yoff+sx] + yy, cb, cr := ycbcr.RGBToYCbCr(col.R, col.G, col.B) + yBlock[8*j+i] = int(yy) + cbBlock[8*j+i] = int(cb) + crBlock[8*j+i] = int(cr) + } + } +} + +// scale scales the 16x16 region represented by the 4 src blocks to the 8x8 +// dst block. +func scale(dst *block, src *[4]block) { + for i := 0; i < 4; i++ { + dstOff := (i&2)<<4 | (i&1)<<2 + for y := 0; y < 4; y++ { + for x := 0; x < 4; x++ { + j := 16*y + 2*x + sum := src[i][j] + src[i][j+1] + src[i][j+8] + src[i][j+9] + dst[8*y+x+dstOff] = (sum + 2) >> 2 + } + } + } +} + +// sosHeader is the SOS marker "\xff\xda" followed by 12 bytes: +// - the marker length "\x00\x0c", +// - the number of components "\x03", +// - component 1 uses DC table 0 and AC table 0 "\x01\x00", +// - component 2 uses DC table 1 and AC table 1 "\x02\x11", +// - component 3 uses DC table 1 and AC table 1 "\x03\x11", +// - padding "\x00\x00\x00". +var sosHeader = []byte{ + 0xff, 0xda, 0x00, 0x0c, 0x03, 0x01, 0x00, 0x02, + 0x11, 0x03, 0x11, 0x00, 0x00, 0x00, +} + +// writeSOS writes the StartOfScan marker. +func (e *encoder) writeSOS(m image.Image) { + e.write(sosHeader) + var ( + // Scratch buffers to hold the YCbCr values. + yBlock block + cbBlock [4]block + crBlock [4]block + cBlock block + // DC components are delta-encoded. + prevDCY, prevDCCb, prevDCCr int + ) + bounds := m.Bounds() + rgba, _ := m.(*image.RGBA) + for y := bounds.Min.Y; y < bounds.Max.Y; y += 16 { + for x := bounds.Min.X; x < bounds.Max.X; x += 16 { + for i := 0; i < 4; i++ { + xOff := (i & 1) * 8 + yOff := (i & 2) * 4 + p := image.Point{x + xOff, y + yOff} + if rgba != nil { + rgbaToYCbCr(rgba, p, &yBlock, &cbBlock[i], &crBlock[i]) + } else { + toYCbCr(m, p, &yBlock, &cbBlock[i], &crBlock[i]) + } + prevDCY = e.writeBlock(&yBlock, 0, prevDCY) + } + scale(&cBlock, &cbBlock) + prevDCCb = e.writeBlock(&cBlock, 1, prevDCCb) + scale(&cBlock, &crBlock) + prevDCCr = e.writeBlock(&cBlock, 1, prevDCCr) + } + } + // Pad the last byte with 1's. + e.emit(0x7f, 7) +} + +// DefaultQuality is the default quality encoding parameter. +const DefaultQuality = 75 + +// Options are the encoding parameters. +// Quality ranges from 1 to 100 inclusive, higher is better. +type Options struct { + Quality int +} + +// Encode writes the Image m to w in JPEG 4:2:0 baseline format with the given +// options. Default parameters are used if a nil *Options is passed. +func Encode(w io.Writer, m image.Image, o *Options) os.Error { + b := m.Bounds() + if b.Dx() >= 1<<16 || b.Dy() >= 1<<16 { + return os.NewError("jpeg: image is too large to encode") + } + var e encoder + if ww, ok := w.(writer); ok { + e.w = ww + } else { + e.w = bufio.NewWriter(w) + } + // Clip quality to [1, 100]. + quality := DefaultQuality + if o != nil { + quality = o.Quality + if quality < 1 { + quality = 1 + } else if quality > 100 { + quality = 100 + } + } + // Convert from a quality rating to a scaling factor. + var scale int + if quality < 50 { + scale = 5000 / quality + } else { + scale = 200 - quality*2 + } + // Initialize the quantization tables. + for i := range e.quant { + for j := range e.quant[i] { + x := int(unscaledQuant[i][j]) + x = (x*scale + 50) / 100 + if x < 1 { + x = 1 + } else if x > 255 { + x = 255 + } + e.quant[i][j] = uint8(x) + } + } + // Write the Start Of Image marker. + e.buf[0] = 0xff + e.buf[1] = 0xd8 + e.write(e.buf[:2]) + // Write the quantization tables. + e.writeDQT() + // Write the image dimensions. + e.writeSOF0(b.Size()) + // Write the Huffman tables. + e.writeDHT() + // Write the image data. + e.writeSOS(m) + // Write the End Of Image marker. + e.buf[0] = 0xff + e.buf[1] = 0xd9 + e.write(e.buf[:2]) + e.flush() + return e.err +} diff --git a/libgo/go/image/jpeg/writer_test.go b/libgo/go/image/jpeg/writer_test.go new file mode 100644 index 00000000000..7aec70f016e --- /dev/null +++ b/libgo/go/image/jpeg/writer_test.go @@ -0,0 +1,115 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jpeg + +import ( + "bytes" + "image" + "image/png" + "io/ioutil" + "rand" + "os" + "testing" +) + +var testCase = []struct { + filename string + quality int + tolerance int64 +}{ + {"../testdata/video-001.png", 1, 24 << 8}, + {"../testdata/video-001.png", 20, 12 << 8}, + {"../testdata/video-001.png", 60, 8 << 8}, + {"../testdata/video-001.png", 80, 6 << 8}, + {"../testdata/video-001.png", 90, 4 << 8}, + {"../testdata/video-001.png", 100, 2 << 8}, +} + +func delta(u0, u1 uint32) int64 { + d := int64(u0) - int64(u1) + if d < 0 { + return -d + } + return d +} + +func readPng(filename string) (image.Image, os.Error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + return png.Decode(f) +} + +func TestWriter(t *testing.T) { + for _, tc := range testCase { + // Read the image. + m0, err := readPng(tc.filename) + if err != nil { + t.Error(tc.filename, err) + continue + } + // Encode that image as JPEG. + buf := bytes.NewBuffer(nil) + err = Encode(buf, m0, &Options{Quality: tc.quality}) + if err != nil { + t.Error(tc.filename, err) + continue + } + // Decode that JPEG. + m1, err := Decode(buf) + if err != nil { + t.Error(tc.filename, err) + continue + } + // Compute the average delta in RGB space. + b := m0.Bounds() + var sum, n int64 + for y := b.Min.Y; y < b.Max.Y; y++ { + for x := b.Min.X; x < b.Max.X; x++ { + c0 := m0.At(x, y) + c1 := m1.At(x, y) + r0, g0, b0, _ := c0.RGBA() + r1, g1, b1, _ := c1.RGBA() + sum += delta(r0, r1) + sum += delta(g0, g1) + sum += delta(b0, b1) + n += 3 + } + } + // Compare the average delta to the tolerance level. + if sum/n > tc.tolerance { + t.Errorf("%s, quality=%d: average delta is too high", tc.filename, tc.quality) + continue + } + } +} + +func BenchmarkEncodeRGBOpaque(b *testing.B) { + b.StopTimer() + img := image.NewRGBA(640, 480) + // Set all pixels to 0xFF alpha to force opaque mode. + bo := img.Bounds() + rnd := rand.New(rand.NewSource(123)) + for y := bo.Min.Y; y < bo.Max.Y; y++ { + for x := bo.Min.X; x < bo.Max.X; x++ { + img.Set(x, y, image.RGBAColor{ + uint8(rnd.Intn(256)), + uint8(rnd.Intn(256)), + uint8(rnd.Intn(256)), + 255}) + } + } + if !img.Opaque() { + panic("expected image to be opaque") + } + b.SetBytes(640 * 480 * 4) + b.StartTimer() + options := &Options{Quality: 90} + for i := 0; i < b.N; i++ { + Encode(ioutil.Discard, img, options) + } +} diff --git a/libgo/go/image/png/reader.go b/libgo/go/image/png/reader.go index eee4eac2e15..8c76afa72c6 100644 --- a/libgo/go/image/png/reader.go +++ b/libgo/go/image/png/reader.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The png package implements a PNG image decoder and encoder. +// Package png implements a PNG image decoder and encoder. // // The PNG specification is at http://www.libpng.org/pub/png/spec/1.2/PNG-Contents.html package png @@ -378,7 +378,7 @@ func (d *decoder) idatReader(idat io.Reader) (image.Image, os.Error) { for x := 0; x < d.width; x += 8 { b := cdat[x/8] for x2 := 0; x2 < 8 && x+x2 < d.width; x2++ { - gray.Set(x+x2, y, image.GrayColor{(b >> 7) * 0xff}) + gray.SetGray(x+x2, y, image.GrayColor{(b >> 7) * 0xff}) b <<= 1 } } @@ -386,7 +386,7 @@ func (d *decoder) idatReader(idat io.Reader) (image.Image, os.Error) { for x := 0; x < d.width; x += 4 { b := cdat[x/4] for x2 := 0; x2 < 4 && x+x2 < d.width; x2++ { - gray.Set(x+x2, y, image.GrayColor{(b >> 6) * 0x55}) + gray.SetGray(x+x2, y, image.GrayColor{(b >> 6) * 0x55}) b <<= 2 } } @@ -394,22 +394,22 @@ func (d *decoder) idatReader(idat io.Reader) (image.Image, os.Error) { for x := 0; x < d.width; x += 2 { b := cdat[x/2] for x2 := 0; x2 < 2 && x+x2 < d.width; x2++ { - gray.Set(x+x2, y, image.GrayColor{(b >> 4) * 0x11}) + gray.SetGray(x+x2, y, image.GrayColor{(b >> 4) * 0x11}) b <<= 4 } } case cbG8: for x := 0; x < d.width; x++ { - gray.Set(x, y, image.GrayColor{cdat[x]}) + gray.SetGray(x, y, image.GrayColor{cdat[x]}) } case cbGA8: for x := 0; x < d.width; x++ { ycol := cdat[2*x+0] - nrgba.Set(x, y, image.NRGBAColor{ycol, ycol, ycol, cdat[2*x+1]}) + nrgba.SetNRGBA(x, y, image.NRGBAColor{ycol, ycol, ycol, cdat[2*x+1]}) } case cbTC8: for x := 0; x < d.width; x++ { - rgba.Set(x, y, image.RGBAColor{cdat[3*x+0], cdat[3*x+1], cdat[3*x+2], 0xff}) + rgba.SetRGBA(x, y, image.RGBAColor{cdat[3*x+0], cdat[3*x+1], cdat[3*x+2], 0xff}) } case cbP1: for x := 0; x < d.width; x += 8 { @@ -456,25 +456,25 @@ func (d *decoder) idatReader(idat io.Reader) (image.Image, os.Error) { } case cbTCA8: for x := 0; x < d.width; x++ { - nrgba.Set(x, y, image.NRGBAColor{cdat[4*x+0], cdat[4*x+1], cdat[4*x+2], cdat[4*x+3]}) + nrgba.SetNRGBA(x, y, image.NRGBAColor{cdat[4*x+0], cdat[4*x+1], cdat[4*x+2], cdat[4*x+3]}) } case cbG16: for x := 0; x < d.width; x++ { ycol := uint16(cdat[2*x+0])<<8 | uint16(cdat[2*x+1]) - gray16.Set(x, y, image.Gray16Color{ycol}) + gray16.SetGray16(x, y, image.Gray16Color{ycol}) } case cbGA16: for x := 0; x < d.width; x++ { ycol := uint16(cdat[4*x+0])<<8 | uint16(cdat[4*x+1]) acol := uint16(cdat[4*x+2])<<8 | uint16(cdat[4*x+3]) - nrgba64.Set(x, y, image.NRGBA64Color{ycol, ycol, ycol, acol}) + nrgba64.SetNRGBA64(x, y, image.NRGBA64Color{ycol, ycol, ycol, acol}) } case cbTC16: for x := 0; x < d.width; x++ { rcol := uint16(cdat[6*x+0])<<8 | uint16(cdat[6*x+1]) gcol := uint16(cdat[6*x+2])<<8 | uint16(cdat[6*x+3]) bcol := uint16(cdat[6*x+4])<<8 | uint16(cdat[6*x+5]) - rgba64.Set(x, y, image.RGBA64Color{rcol, gcol, bcol, 0xffff}) + rgba64.SetRGBA64(x, y, image.RGBA64Color{rcol, gcol, bcol, 0xffff}) } case cbTCA16: for x := 0; x < d.width; x++ { @@ -482,7 +482,7 @@ func (d *decoder) idatReader(idat io.Reader) (image.Image, os.Error) { gcol := uint16(cdat[8*x+2])<<8 | uint16(cdat[8*x+3]) bcol := uint16(cdat[8*x+4])<<8 | uint16(cdat[8*x+5]) acol := uint16(cdat[8*x+6])<<8 | uint16(cdat[8*x+7]) - nrgba64.Set(x, y, image.NRGBA64Color{rcol, gcol, bcol, acol}) + nrgba64.SetNRGBA64(x, y, image.NRGBA64Color{rcol, gcol, bcol, acol}) } } diff --git a/libgo/go/image/png/reader_test.go b/libgo/go/image/png/reader_test.go index efa6336d792..bcc1a3db475 100644 --- a/libgo/go/image/png/reader_test.go +++ b/libgo/go/image/png/reader_test.go @@ -28,6 +28,7 @@ var filenames = []string{ "basn3p02", "basn3p04", "basn3p08", + "basn3p08-trns", "basn4a08", "basn4a16", "basn6a08", @@ -98,17 +99,30 @@ func sng(w io.WriteCloser, filename string, png image.Image) { // (the PNG spec section 11.3 says "Ancillary chunks may be ignored by a decoder"). io.WriteString(w, "gAMA {1.0000}\n") - // Write the PLTE (if applicable). + // Write the PLTE and tRNS (if applicable). if cpm != nil { + lastAlpha := -1 io.WriteString(w, "PLTE {\n") - for i := 0; i < len(cpm); i++ { - r, g, b, _ := cpm[i].RGBA() + for i, c := range cpm { + r, g, b, a := c.RGBA() + if a != 0xffff { + lastAlpha = i + } r >>= 8 g >>= 8 b >>= 8 fmt.Fprintf(w, " (%3d,%3d,%3d) # rgb = (0x%02x,0x%02x,0x%02x)\n", r, g, b, r, g, b) } io.WriteString(w, "}\n") + if lastAlpha != -1 { + io.WriteString(w, "tRNS {\n") + for i := 0; i <= lastAlpha; i++ { + _, _, _, a := cpm[i].RGBA() + a >>= 8 + fmt.Fprintf(w, " %d", a) + } + io.WriteString(w, "}\n") + } } // Write the IMAGE. diff --git a/libgo/go/image/png/testdata/pngsuite/README b/libgo/go/image/png/testdata/pngsuite/README index abe3ecb201d..c0f78bde87a 100644 --- a/libgo/go/image/png/testdata/pngsuite/README +++ b/libgo/go/image/png/testdata/pngsuite/README @@ -10,6 +10,9 @@ The files basn0g01-30.png, basn0g02-29.png and basn0g04-31.png are in fact not part of pngsuite but were created from files in pngsuite. Their non-power- of-two sizes makes them useful for testing bit-depths smaller than a byte. +basn3a08.png was generated from basn6a08.png using the pngnq tool, which +converted it to the 8-bit paletted image with alpha values in tRNS chunk. + The *.sng files in this directory were generated from the *.png files by the sng command-line tool and some hand editing. The files basn0g0{1,2,4}.sng were actually generated by first converting the PNG diff --git a/libgo/go/image/png/writer.go b/libgo/go/image/png/writer.go index 081d06bf571..a27586f2394 100644 --- a/libgo/go/image/png/writer.go +++ b/libgo/go/image/png/writer.go @@ -130,12 +130,8 @@ func (e *encoder) writePLTE(p image.PalettedColorModel) { e.err = FormatError("bad palette length: " + strconv.Itoa(len(p))) return } - for i := 0; i < len(p); i++ { - r, g, b, a := p[i].RGBA() - if a != 0xffff { - e.err = UnsupportedError("non-opaque palette color") - return - } + for i, c := range p { + r, g, b, _ := c.RGBA() e.tmp[3*i+0] = uint8(r >> 8) e.tmp[3*i+1] = uint8(g >> 8) e.tmp[3*i+2] = uint8(b >> 8) @@ -143,6 +139,21 @@ func (e *encoder) writePLTE(p image.PalettedColorModel) { e.writeChunk(e.tmp[0:3*len(p)], "PLTE") } +func (e *encoder) maybeWritetRNS(p image.PalettedColorModel) { + last := -1 + for i, c := range p { + _, _, _, a := c.RGBA() + if a != 0xffff { + last = i + } + e.tmp[i] = uint8(a >> 8) + } + if last == -1 { + return + } + e.writeChunk(e.tmp[:last+1], "tRNS") +} + // An encoder is an io.Writer that satisfies writes by writing PNG IDAT chunks, // including an 8-byte header and 4-byte CRC checksum per Write call. Such calls // should be relatively infrequent, since writeIDATs uses a bufio.Writer. @@ -263,7 +274,12 @@ func writeImage(w io.Writer, m image.Image, cb int) os.Error { defer zw.Close() bpp := 0 // Bytes per pixel. + + // Used by fast paths for common image types var paletted *image.Paletted + var rgba *image.RGBA + rgba, _ = m.(*image.RGBA) + switch cb { case cbG8: bpp = 1 @@ -303,12 +319,24 @@ func writeImage(w io.Writer, m image.Image, cb int) os.Error { cr[0][x+1] = c.Y } case cbTC8: - for x := b.Min.X; x < b.Max.X; x++ { - // We have previously verified that the alpha value is fully opaque. - r, g, b, _ := m.At(x, y).RGBA() - cr[0][3*x+1] = uint8(r >> 8) - cr[0][3*x+2] = uint8(g >> 8) - cr[0][3*x+3] = uint8(b >> 8) + // We have previously verified that the alpha value is fully opaque. + cr0 := cr[0] + if rgba != nil { + yoff := y * rgba.Stride + xoff := 3*b.Min.X + 1 + for _, color := range rgba.Pix[yoff+b.Min.X : yoff+b.Max.X] { + cr0[xoff] = color.R + cr0[xoff+1] = color.G + cr0[xoff+2] = color.B + xoff += 3 + } + } else { + for x := b.Min.X; x < b.Max.X; x++ { + r, g, b, _ := m.At(x, y).RGBA() + cr0[3*x+1] = uint8(r >> 8) + cr0[3*x+2] = uint8(g >> 8) + cr0[3*x+3] = uint8(b >> 8) + } } case cbP8: rowOffset := y * paletted.Stride @@ -430,6 +458,7 @@ func Encode(w io.Writer, m image.Image) os.Error { e.writeIHDR() if pal != nil { e.writePLTE(pal.Palette) + e.maybeWritetRNS(pal.Palette) } e.writeIDATs() e.writeIEND() diff --git a/libgo/go/image/png/writer_test.go b/libgo/go/image/png/writer_test.go index 4d9929f314f..6b054aaa893 100644 --- a/libgo/go/image/png/writer_test.go +++ b/libgo/go/image/png/writer_test.go @@ -5,10 +5,10 @@ package png import ( - "bytes" "fmt" "image" "io" + "io/ioutil" "os" "testing" ) @@ -81,10 +81,42 @@ func BenchmarkEncodePaletted(b *testing.B) { image.RGBAColor{0, 0, 0, 255}, image.RGBAColor{255, 255, 255, 255}, }) + b.SetBytes(640 * 480 * 1) b.StartTimer() - buffer := new(bytes.Buffer) for i := 0; i < b.N; i++ { - buffer.Reset() - Encode(buffer, img) + Encode(ioutil.Discard, img) + } +} + +func BenchmarkEncodeRGBOpaque(b *testing.B) { + b.StopTimer() + img := image.NewRGBA(640, 480) + // Set all pixels to 0xFF alpha to force opaque mode. + bo := img.Bounds() + for y := bo.Min.Y; y < bo.Max.Y; y++ { + for x := bo.Min.X; x < bo.Max.X; x++ { + img.Set(x, y, image.RGBAColor{0, 0, 0, 255}) + } + } + if !img.Opaque() { + panic("expected image to be opaque") + } + b.SetBytes(640 * 480 * 4) + b.StartTimer() + for i := 0; i < b.N; i++ { + Encode(ioutil.Discard, img) + } +} + +func BenchmarkEncodeRGBA(b *testing.B) { + b.StopTimer() + img := image.NewRGBA(640, 480) + if img.Opaque() { + panic("expected image to not be opaque") + } + b.SetBytes(640 * 480 * 4) + b.StartTimer() + for i := 0; i < b.N; i++ { + Encode(ioutil.Discard, img) } } diff --git a/libgo/go/image/testdata/video-001.bmp b/libgo/go/image/testdata/video-001.bmp new file mode 100644 index 00000000000..ca3dd42a7c9 Binary files /dev/null and b/libgo/go/image/testdata/video-001.bmp differ diff --git a/libgo/go/image/testdata/video-001.gif b/libgo/go/image/testdata/video-001.gif new file mode 100644 index 00000000000..ca06af61bba Binary files /dev/null and b/libgo/go/image/testdata/video-001.gif differ diff --git a/libgo/go/image/testdata/video-001.jpeg b/libgo/go/image/testdata/video-001.jpeg new file mode 100644 index 00000000000..1b87c933bb7 Binary files /dev/null and b/libgo/go/image/testdata/video-001.jpeg differ diff --git a/libgo/go/image/testdata/video-001.png b/libgo/go/image/testdata/video-001.png new file mode 100644 index 00000000000..d3468bbe8fc Binary files /dev/null and b/libgo/go/image/testdata/video-001.png differ diff --git a/libgo/go/image/testdata/video-001.tiff b/libgo/go/image/testdata/video-001.tiff new file mode 100644 index 00000000000..0dd6cd93133 Binary files /dev/null and b/libgo/go/image/testdata/video-001.tiff differ diff --git a/libgo/go/image/tiff/buffer.go b/libgo/go/image/tiff/buffer.go new file mode 100644 index 00000000000..7c0714225f1 --- /dev/null +++ b/libgo/go/image/tiff/buffer.go @@ -0,0 +1,57 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tiff + +import ( + "io" + "os" +) + +// buffer buffers an io.Reader to satisfy io.ReaderAt. +type buffer struct { + r io.Reader + buf []byte +} + +func (b *buffer) ReadAt(p []byte, off int64) (int, os.Error) { + o := int(off) + end := o + len(p) + if int64(end) != off+int64(len(p)) { + return 0, os.EINVAL + } + + m := len(b.buf) + if end > m { + if end > cap(b.buf) { + newcap := 1024 + for newcap < end { + newcap *= 2 + } + newbuf := make([]byte, end, newcap) + copy(newbuf, b.buf) + b.buf = newbuf + } else { + b.buf = b.buf[:end] + } + if n, err := io.ReadFull(b.r, b.buf[m:end]); err != nil { + end = m + n + b.buf = b.buf[:end] + return copy(p, b.buf[o:end]), err + } + } + + return copy(p, b.buf[o:end]), nil +} + +// newReaderAt converts an io.Reader into an io.ReaderAt. +func newReaderAt(r io.Reader) io.ReaderAt { + if ra, ok := r.(io.ReaderAt); ok { + return ra + } + return &buffer{ + r: r, + buf: make([]byte, 0, 1024), + } +} diff --git a/libgo/go/image/tiff/buffer_test.go b/libgo/go/image/tiff/buffer_test.go new file mode 100644 index 00000000000..4f3e68e838c --- /dev/null +++ b/libgo/go/image/tiff/buffer_test.go @@ -0,0 +1,36 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tiff + +import ( + "os" + "strings" + "testing" +) + +var readAtTests = []struct { + n int + off int64 + s string + err os.Error +}{ + {2, 0, "ab", nil}, + {6, 0, "abcdef", nil}, + {3, 3, "def", nil}, + {3, 5, "f", os.EOF}, + {3, 6, "", os.EOF}, +} + +func TestReadAt(t *testing.T) { + r := newReaderAt(strings.NewReader("abcdef")) + b := make([]byte, 10) + for _, test := range readAtTests { + n, err := r.ReadAt(b[:test.n], test.off) + s := string(b[:n]) + if s != test.s || err != test.err { + t.Errorf("buffer.ReadAt(<%v bytes>, %v): got %v, %q; want %v, %q", test.n, test.off, err, s, test.err, test.s) + } + } +} diff --git a/libgo/go/image/tiff/consts.go b/libgo/go/image/tiff/consts.go new file mode 100644 index 00000000000..761ac9d9094 --- /dev/null +++ b/libgo/go/image/tiff/consts.go @@ -0,0 +1,102 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tiff + +// A tiff image file contains one or more images. The metadata +// of each image is contained in an Image File Directory (IFD), +// which contains entries of 12 bytes each and is described +// on page 14-16 of the specification. An IFD entry consists of +// +// - a tag, which describes the signification of the entry, +// - the data type and length of the entry, +// - the data itself or a pointer to it if it is more than 4 bytes. +// +// The presence of a length means that each IFD is effectively an array. + +const ( + leHeader = "II\x2A\x00" // Header for little-endian files. + beHeader = "MM\x00\x2A" // Header for big-endian files. + + ifdLen = 12 // Length of an IFD entry in bytes. +) + +// Data types (p. 14-16 of the spec). +const ( + dtByte = 1 + dtASCII = 2 + dtShort = 3 + dtLong = 4 + dtRational = 5 +) + +// The length of one instance of each data type in bytes. +var lengths = [...]uint32{0, 1, 1, 2, 4, 8} + +// Tags (see p. 28-41 of the spec). +const ( + tImageWidth = 256 + tImageLength = 257 + tBitsPerSample = 258 + tCompression = 259 + tPhotometricInterpretation = 262 + + tStripOffsets = 273 + tSamplesPerPixel = 277 + tRowsPerStrip = 278 + tStripByteCounts = 279 + + tXResolution = 282 + tYResolution = 283 + tResolutionUnit = 296 + + tPredictor = 317 + tColorMap = 320 + tExtraSamples = 338 +) + +// Compression types (defined in various places in the spec and supplements). +const ( + cNone = 1 + cCCITT = 2 + cG3 = 3 // Group 3 Fax. + cG4 = 4 // Group 4 Fax. + cLZW = 5 + cJPEGOld = 6 // Superseded by cJPEG. + cJPEG = 7 + cDeflate = 8 // zlib compression. + cPackBits = 32773 + cDeflateOld = 32946 // Superseded by cDeflate. +) + +// Photometric interpretation values (see p. 37 of the spec). +const ( + pWhiteIsZero = 0 + pBlackIsZero = 1 + pRGB = 2 + pPaletted = 3 + pTransMask = 4 // transparency mask + pCMYK = 5 + pYCbCr = 6 + pCIELab = 8 +) + +// Values for the tPredictor tag (page 64-65 of the spec). +const ( + prNone = 1 + prHorizontal = 2 +) + +// imageMode represents the mode of the image. +type imageMode int + +const ( + mBilevel imageMode = iota + mPaletted + mGray + mGrayInvert + mRGB + mRGBA + mNRGBA +) diff --git a/libgo/go/image/tiff/reader.go b/libgo/go/image/tiff/reader.go new file mode 100644 index 00000000000..40f659c36c8 --- /dev/null +++ b/libgo/go/image/tiff/reader.go @@ -0,0 +1,385 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package tiff implements a TIFF image decoder. +// +// The TIFF specification is at http://partners.adobe.com/public/developer/en/tiff/TIFF6.pdf +package tiff + +import ( + "compress/lzw" + "compress/zlib" + "encoding/binary" + "image" + "io" + "io/ioutil" + "os" +) + +// A FormatError reports that the input is not a valid TIFF image. +type FormatError string + +func (e FormatError) String() string { + return "tiff: invalid format: " + string(e) +} + +// An UnsupportedError reports that the input uses a valid but +// unimplemented feature. +type UnsupportedError string + +func (e UnsupportedError) String() string { + return "tiff: unsupported feature: " + string(e) +} + +// An InternalError reports that an internal error was encountered. +type InternalError string + +func (e InternalError) String() string { + return "tiff: internal error: " + string(e) +} + +type decoder struct { + r io.ReaderAt + byteOrder binary.ByteOrder + config image.Config + mode imageMode + features map[int][]uint + palette []image.Color +} + +// firstVal returns the first uint of the features entry with the given tag, +// or 0 if the tag does not exist. +func (d *decoder) firstVal(tag int) uint { + f := d.features[tag] + if len(f) == 0 { + return 0 + } + return f[0] +} + +// ifdUint decodes the IFD entry in p, which must be of the Byte, Short +// or Long type, and returns the decoded uint values. +func (d *decoder) ifdUint(p []byte) (u []uint, err os.Error) { + var raw []byte + datatype := d.byteOrder.Uint16(p[2:4]) + count := d.byteOrder.Uint32(p[4:8]) + if datalen := lengths[datatype] * count; datalen > 4 { + // The IFD contains a pointer to the real value. + raw = make([]byte, datalen) + _, err = d.r.ReadAt(raw, int64(d.byteOrder.Uint32(p[8:12]))) + } else { + raw = p[8 : 8+datalen] + } + if err != nil { + return nil, err + } + + u = make([]uint, count) + switch datatype { + case dtByte: + for i := uint32(0); i < count; i++ { + u[i] = uint(raw[i]) + } + case dtShort: + for i := uint32(0); i < count; i++ { + u[i] = uint(d.byteOrder.Uint16(raw[2*i : 2*(i+1)])) + } + case dtLong: + for i := uint32(0); i < count; i++ { + u[i] = uint(d.byteOrder.Uint32(raw[4*i : 4*(i+1)])) + } + default: + return nil, UnsupportedError("data type") + } + return u, nil +} + +// parseIFD decides whether the the IFD entry in p is "interesting" and +// stows away the data in the decoder. +func (d *decoder) parseIFD(p []byte) os.Error { + tag := d.byteOrder.Uint16(p[0:2]) + switch tag { + case tBitsPerSample, + tExtraSamples, + tPhotometricInterpretation, + tCompression, + tPredictor, + tStripOffsets, + tStripByteCounts, + tRowsPerStrip, + tImageLength, + tImageWidth: + val, err := d.ifdUint(p) + if err != nil { + return err + } + d.features[int(tag)] = val + case tColorMap: + val, err := d.ifdUint(p) + if err != nil { + return err + } + numcolors := len(val) / 3 + if len(val)%3 != 0 || numcolors <= 0 || numcolors > 256 { + return FormatError("bad ColorMap length") + } + d.palette = make([]image.Color, numcolors) + for i := 0; i < numcolors; i++ { + d.palette[i] = image.RGBA64Color{ + uint16(val[i]), + uint16(val[i+numcolors]), + uint16(val[i+2*numcolors]), + 0xffff, + } + } + } + return nil +} + +// decode decodes the raw data of an image with 8 bits in each sample. +// It reads from p and writes the strip with ymin <= y < ymax into dst. +func (d *decoder) decode(dst image.Image, p []byte, ymin, ymax int) os.Error { + spp := len(d.features[tBitsPerSample]) // samples per pixel + off := 0 + width := dst.Bounds().Dx() + + if len(p) < spp*(ymax-ymin)*width { + return FormatError("short data strip") + } + + // Apply horizontal predictor if necessary. + // In this case, p contains the color difference to the preceding pixel. + // See page 64-65 of the spec. + if d.firstVal(tPredictor) == prHorizontal { + for y := ymin; y < ymax; y++ { + off += spp + for x := 0; x < (width-1)*spp; x++ { + p[off] += p[off-spp] + off++ + } + } + off = 0 + } + + switch d.mode { + case mGray: + img := dst.(*image.Gray) + for y := ymin; y < ymax; y++ { + for x := img.Rect.Min.X; x < img.Rect.Max.X; x++ { + img.Set(x, y, image.GrayColor{p[off]}) + off += spp + } + } + case mGrayInvert: + img := dst.(*image.Gray) + for y := ymin; y < ymax; y++ { + for x := img.Rect.Min.X; x < img.Rect.Max.X; x++ { + img.Set(x, y, image.GrayColor{0xff - p[off]}) + off += spp + } + } + case mPaletted: + img := dst.(*image.Paletted) + for y := ymin; y < ymax; y++ { + for x := img.Rect.Min.X; x < img.Rect.Max.X; x++ { + img.SetColorIndex(x, y, p[off]) + off += spp + } + } + case mRGB: + img := dst.(*image.RGBA) + for y := ymin; y < ymax; y++ { + for x := img.Rect.Min.X; x < img.Rect.Max.X; x++ { + img.Set(x, y, image.RGBAColor{p[off], p[off+1], p[off+2], 0xff}) + off += spp + } + } + case mNRGBA: + img := dst.(*image.NRGBA) + for y := ymin; y < ymax; y++ { + for x := img.Rect.Min.X; x < img.Rect.Max.X; x++ { + img.Set(x, y, image.NRGBAColor{p[off], p[off+1], p[off+2], p[off+3]}) + off += spp + } + } + case mRGBA: + img := dst.(*image.RGBA) + for y := ymin; y < ymax; y++ { + for x := img.Rect.Min.X; x < img.Rect.Max.X; x++ { + img.Set(x, y, image.RGBAColor{p[off], p[off+1], p[off+2], p[off+3]}) + off += spp + } + } + } + + return nil +} + +func newDecoder(r io.Reader) (*decoder, os.Error) { + d := &decoder{ + r: newReaderAt(r), + features: make(map[int][]uint), + } + + p := make([]byte, 8) + if _, err := d.r.ReadAt(p, 0); err != nil { + return nil, err + } + switch string(p[0:4]) { + case leHeader: + d.byteOrder = binary.LittleEndian + case beHeader: + d.byteOrder = binary.BigEndian + default: + return nil, FormatError("malformed header") + } + + ifdOffset := int64(d.byteOrder.Uint32(p[4:8])) + + // The first two bytes contain the number of entries (12 bytes each). + if _, err := d.r.ReadAt(p[0:2], ifdOffset); err != nil { + return nil, err + } + numItems := int(d.byteOrder.Uint16(p[0:2])) + + // All IFD entries are read in one chunk. + p = make([]byte, ifdLen*numItems) + if _, err := d.r.ReadAt(p, ifdOffset+2); err != nil { + return nil, err + } + + for i := 0; i < len(p); i += ifdLen { + if err := d.parseIFD(p[i : i+ifdLen]); err != nil { + return nil, err + } + } + + d.config.Width = int(d.firstVal(tImageWidth)) + d.config.Height = int(d.firstVal(tImageLength)) + + // Determine the image mode. + switch d.firstVal(tPhotometricInterpretation) { + case pRGB: + d.config.ColorModel = image.RGBAColorModel + // RGB images normally have 3 samples per pixel. + // If there are more, ExtraSamples (p. 31-32 of the spec) + // gives their meaning (usually an alpha channel). + switch len(d.features[tBitsPerSample]) { + case 3: + d.mode = mRGB + case 4: + switch d.firstVal(tExtraSamples) { + case 1: + d.mode = mRGBA + case 2: + d.mode = mNRGBA + d.config.ColorModel = image.NRGBAColorModel + default: + // The extra sample is discarded. + d.mode = mRGB + } + default: + return nil, FormatError("wrong number of samples for RGB") + } + case pPaletted: + d.mode = mPaletted + d.config.ColorModel = image.PalettedColorModel(d.palette) + case pWhiteIsZero: + d.mode = mGrayInvert + d.config.ColorModel = image.GrayColorModel + case pBlackIsZero: + d.mode = mGray + d.config.ColorModel = image.GrayColorModel + default: + return nil, UnsupportedError("color model") + } + + if _, ok := d.features[tBitsPerSample]; !ok { + return nil, FormatError("BitsPerSample tag missing") + } + for _, b := range d.features[tBitsPerSample] { + if b != 8 { + return nil, UnsupportedError("not an 8-bit image") + } + } + + return d, nil +} + +// DecodeConfig returns the color model and dimensions of a TIFF image without +// decoding the entire image. +func DecodeConfig(r io.Reader) (image.Config, os.Error) { + d, err := newDecoder(r) + if err != nil { + return image.Config{}, err + } + return d.config, nil +} + +// Decode reads a TIFF image from r and returns it as an image.Image. +// The type of Image returned depends on the contents of the TIFF. +func Decode(r io.Reader) (img image.Image, err os.Error) { + d, err := newDecoder(r) + if err != nil { + return + } + + // Check if we have the right number of strips, offsets and counts. + rps := int(d.firstVal(tRowsPerStrip)) + numStrips := (d.config.Height + rps - 1) / rps + if rps == 0 || len(d.features[tStripOffsets]) < numStrips || len(d.features[tStripByteCounts]) < numStrips { + return nil, FormatError("inconsistent header") + } + + switch d.mode { + case mGray, mGrayInvert: + img = image.NewGray(d.config.Width, d.config.Height) + case mPaletted: + img = image.NewPaletted(d.config.Width, d.config.Height, d.palette) + case mNRGBA: + img = image.NewNRGBA(d.config.Width, d.config.Height) + case mRGB, mRGBA: + img = image.NewRGBA(d.config.Width, d.config.Height) + } + + var p []byte + for i := 0; i < numStrips; i++ { + ymin := i * rps + // The last strip may be shorter. + if i == numStrips-1 && d.config.Height%rps != 0 { + rps = d.config.Height % rps + } + offset := int64(d.features[tStripOffsets][i]) + n := int64(d.features[tStripByteCounts][i]) + switch d.firstVal(tCompression) { + case cNone: + // TODO(bsiegert): Avoid copy if r is a tiff.buffer. + p = make([]byte, 0, n) + _, err = d.r.ReadAt(p, offset) + case cLZW: + r := lzw.NewReader(io.NewSectionReader(d.r, offset, n), lzw.MSB, 8) + p, err = ioutil.ReadAll(r) + r.Close() + case cDeflate, cDeflateOld: + r, err := zlib.NewReader(io.NewSectionReader(d.r, offset, n)) + if err != nil { + return nil, err + } + p, err = ioutil.ReadAll(r) + r.Close() + default: + err = UnsupportedError("compression") + } + if err != nil { + return + } + err = d.decode(img, p, ymin, ymin+rps) + } + return +} + +func init() { + image.RegisterFormat("tiff", leHeader, Decode, DecodeConfig) + image.RegisterFormat("tiff", beHeader, Decode, DecodeConfig) +} diff --git a/libgo/go/image/ycbcr/ycbcr.go b/libgo/go/image/ycbcr/ycbcr.go new file mode 100644 index 00000000000..cda45996df0 --- /dev/null +++ b/libgo/go/image/ycbcr/ycbcr.go @@ -0,0 +1,174 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package ycbcr provides images from the Y'CbCr color model. +// +// JPEG, VP8, the MPEG family and other codecs use this color model. Such +// codecs often use the terms YUV and Y'CbCr interchangeably, but strictly +// speaking, the term YUV applies only to analog video signals. +// +// Conversion between RGB and Y'CbCr is lossy and there are multiple, slightly +// different formulae for converting between the two. This package follows +// the JFIF specification at http://www.w3.org/Graphics/JPEG/jfif3.pdf. +package ycbcr + +import ( + "image" +) + +// RGBToYCbCr converts an RGB triple to a YCbCr triple. All components lie +// within the range [0, 255]. +func RGBToYCbCr(r, g, b uint8) (uint8, uint8, uint8) { + // The JFIF specification says: + // Y' = 0.2990*R + 0.5870*G + 0.1140*B + // Cb = -0.1687*R - 0.3313*G + 0.5000*B + 128 + // Cr = 0.5000*R - 0.4187*G - 0.0813*B + 128 + // http://www.w3.org/Graphics/JPEG/jfif3.pdf says Y but means Y'. + r1 := int(r) + g1 := int(g) + b1 := int(b) + yy := (19595*r1 + 38470*g1 + 7471*b1 + 1<<15) >> 16 + cb := (-11056*r1 - 21712*g1 + 32768*b1 + 257<<15) >> 16 + cr := (32768*r1 - 27440*g1 - 5328*b1 + 257<<15) >> 16 + if yy < 0 { + yy = 0 + } else if yy > 255 { + yy = 255 + } + if cb < 0 { + cb = 0 + } else if cb > 255 { + cb = 255 + } + if cr < 0 { + cr = 0 + } else if cr > 255 { + cr = 255 + } + return uint8(yy), uint8(cb), uint8(cr) +} + +// YCbCrToRGB converts a YCbCr triple to an RGB triple. All components lie +// within the range [0, 255]. +func YCbCrToRGB(y, cb, cr uint8) (uint8, uint8, uint8) { + // The JFIF specification says: + // R = Y' + 1.40200*(Cr-128) + // G = Y' - 0.34414*(Cb-128) - 0.71414*(Cr-128) + // B = Y' + 1.77200*(Cb-128) + // http://www.w3.org/Graphics/JPEG/jfif3.pdf says Y but means Y'. + yy1 := int(y)<<16 + 1<<15 + cb1 := int(cb) - 128 + cr1 := int(cr) - 128 + r := (yy1 + 91881*cr1) >> 16 + g := (yy1 - 22554*cb1 - 46802*cr1) >> 16 + b := (yy1 + 116130*cb1) >> 16 + if r < 0 { + r = 0 + } else if r > 255 { + r = 255 + } + if g < 0 { + g = 0 + } else if g > 255 { + g = 255 + } + if b < 0 { + b = 0 + } else if b > 255 { + b = 255 + } + return uint8(r), uint8(g), uint8(b) +} + +// YCbCrColor represents a fully opaque 24-bit Y'CbCr color, having 8 bits for +// each of one luma and two chroma components. +type YCbCrColor struct { + Y, Cb, Cr uint8 +} + +func (c YCbCrColor) RGBA() (uint32, uint32, uint32, uint32) { + r, g, b := YCbCrToRGB(c.Y, c.Cb, c.Cr) + return uint32(r) * 0x101, uint32(g) * 0x101, uint32(b) * 0x101, 0xffff +} + +func toYCbCrColor(c image.Color) image.Color { + if _, ok := c.(YCbCrColor); ok { + return c + } + r, g, b, _ := c.RGBA() + y, u, v := RGBToYCbCr(uint8(r>>8), uint8(g>>8), uint8(b>>8)) + return YCbCrColor{y, u, v} +} + +// YCbCrColorModel is the color model for YCbCrColor. +var YCbCrColorModel image.ColorModel = image.ColorModelFunc(toYCbCrColor) + +// SubsampleRatio is the chroma subsample ratio used in a YCbCr image. +type SubsampleRatio int + +const ( + SubsampleRatio444 SubsampleRatio = iota + SubsampleRatio422 + SubsampleRatio420 +) + +// YCbCr is an in-memory image of YCbCr colors. There is one Y sample per pixel, +// but each Cb and Cr sample can span one or more pixels. +// YStride is the Y slice index delta between vertically adjacent pixels. +// CStride is the Cb and Cr slice index delta between vertically adjacent pixels +// that map to separate chroma samples. +// It is not an absolute requirement, but YStride and len(Y) are typically +// multiples of 8, and: +// For 4:4:4, CStride == YStride/1 && len(Cb) == len(Cr) == len(Y)/1. +// For 4:2:2, CStride == YStride/2 && len(Cb) == len(Cr) == len(Y)/2. +// For 4:2:0, CStride == YStride/2 && len(Cb) == len(Cr) == len(Y)/4. +type YCbCr struct { + Y []uint8 + Cb []uint8 + Cr []uint8 + YStride int + CStride int + SubsampleRatio SubsampleRatio + Rect image.Rectangle +} + +func (p *YCbCr) ColorModel() image.ColorModel { + return YCbCrColorModel +} + +func (p *YCbCr) Bounds() image.Rectangle { + return p.Rect +} + +func (p *YCbCr) At(x, y int) image.Color { + if !p.Rect.Contains(image.Point{x, y}) { + return YCbCrColor{} + } + switch p.SubsampleRatio { + case SubsampleRatio422: + i := x / 2 + return YCbCrColor{ + p.Y[y*p.YStride+x], + p.Cb[y*p.CStride+i], + p.Cr[y*p.CStride+i], + } + case SubsampleRatio420: + i, j := x/2, y/2 + return YCbCrColor{ + p.Y[y*p.YStride+x], + p.Cb[j*p.CStride+i], + p.Cr[j*p.CStride+i], + } + } + // Default to 4:4:4 subsampling. + return YCbCrColor{ + p.Y[y*p.YStride+x], + p.Cb[y*p.CStride+x], + p.Cr[y*p.CStride+x], + } +} + +func (p *YCbCr) Opaque() bool { + return true +} diff --git a/libgo/go/image/ycbcr/ycbcr_test.go b/libgo/go/image/ycbcr/ycbcr_test.go new file mode 100644 index 00000000000..2e60a6f61f9 --- /dev/null +++ b/libgo/go/image/ycbcr/ycbcr_test.go @@ -0,0 +1,33 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ycbcr + +import ( + "testing" +) + +func delta(x, y uint8) uint8 { + if x >= y { + return x - y + } + return y - x +} + +// Test that a subset of RGB space can be converted to YCbCr and back to within +// 1/256 tolerance. +func TestRoundtrip(t *testing.T) { + for r := 0; r < 255; r += 7 { + for g := 0; g < 255; g += 5 { + for b := 0; b < 255; b += 3 { + r0, g0, b0 := uint8(r), uint8(g), uint8(b) + y, cb, cr := RGBToYCbCr(r0, g0, b0) + r1, g1, b1 := YCbCrToRGB(y, cb, cr) + if delta(r0, r1) > 1 || delta(g0, g1) > 1 || delta(b0, b1) > 1 { + t.Fatalf("r0, g0, b0 = %d, %d, %d r1, g1, b1 = %d, %d, %d", r0, g0, b0, r1, g1, b1) + } + } + } + } +} diff --git a/libgo/go/index/suffixarray/suffixarray.go b/libgo/go/index/suffixarray/suffixarray.go index d8c6fc91b48..079b7d8ed0b 100644 --- a/libgo/go/index/suffixarray/suffixarray.go +++ b/libgo/go/index/suffixarray/suffixarray.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The suffixarray package implements substring search in logarithmic time -// using an in-memory suffix array. +// Package suffixarray implements substring search in logarithmic time using +// an in-memory suffix array. // // Example use: // diff --git a/libgo/go/io/io.go b/libgo/go/io/io.go index 3b879189798..0bc73d67dd9 100644 --- a/libgo/go/io/io.go +++ b/libgo/go/io/io.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package provides basic interfaces to I/O primitives. +// Package io provides basic interfaces to I/O primitives. // Its primary job is to wrap existing implementations of such primitives, // such as those in package os, into shared public interfaces that // abstract the functionality, plus some other related primitives. @@ -136,6 +136,10 @@ type WriterTo interface { // At the end of the input stream, ReadAt returns 0, os.EOF. // ReadAt may return a non-zero number of bytes with a non-nil err. // In particular, a ReadAt that exhausts the input may return n > 0, os.EOF. +// +// If ReadAt is reading from an data stream with a seek offset, +// ReadAt should not affect nor be affected by the underlying +// seek offset. type ReaderAt interface { ReadAt(p []byte, off int64) (n int, err os.Error) } @@ -182,16 +186,16 @@ func ReadAtLeast(r Reader, buf []byte, min int) (n int, err os.Error) { if len(buf) < min { return 0, ErrShortBuffer } - for n < min { - nn, e := r.Read(buf[n:]) - if nn > 0 { - n += nn - } - if e != nil { - if e == os.EOF && n > 0 { - e = ErrUnexpectedEOF - } - return n, e + for n < min && err == nil { + var nn int + nn, err = r.Read(buf[n:]) + n += nn + } + if err == os.EOF { + if n >= min { + err = nil + } else if n > 0 { + err = ErrUnexpectedEOF } } return diff --git a/libgo/go/io/io_test.go b/libgo/go/io/io_test.go index 4fcd85e693e..bc4f354af40 100644 --- a/libgo/go/io/io_test.go +++ b/libgo/go/io/io_test.go @@ -118,27 +118,50 @@ func TestCopynEOF(t *testing.T) { func TestReadAtLeast(t *testing.T) { var rb bytes.Buffer + testReadAtLeast(t, &rb) +} + +// A version of bytes.Buffer that returns n > 0, os.EOF on Read +// when the input is exhausted. +type dataAndEOFBuffer struct { + bytes.Buffer +} + +func (r *dataAndEOFBuffer) Read(p []byte) (n int, err os.Error) { + n, err = r.Buffer.Read(p) + if n > 0 && r.Buffer.Len() == 0 && err == nil { + err = os.EOF + } + return +} + +func TestReadAtLeastWithDataAndEOF(t *testing.T) { + var rb dataAndEOFBuffer + testReadAtLeast(t, &rb) +} + +func testReadAtLeast(t *testing.T, rb ReadWriter) { rb.Write([]byte("0123")) buf := make([]byte, 2) - n, err := ReadAtLeast(&rb, buf, 2) + n, err := ReadAtLeast(rb, buf, 2) if err != nil { t.Error(err) } - n, err = ReadAtLeast(&rb, buf, 4) + n, err = ReadAtLeast(rb, buf, 4) if err != ErrShortBuffer { t.Errorf("expected ErrShortBuffer got %v", err) } if n != 0 { t.Errorf("expected to have read 0 bytes, got %v", n) } - n, err = ReadAtLeast(&rb, buf, 1) + n, err = ReadAtLeast(rb, buf, 1) if err != nil { t.Error(err) } if n != 2 { t.Errorf("expected to have read 2 bytes, got %v", n) } - n, err = ReadAtLeast(&rb, buf, 2) + n, err = ReadAtLeast(rb, buf, 2) if err != os.EOF { t.Errorf("expected EOF, got %v", err) } @@ -146,7 +169,7 @@ func TestReadAtLeast(t *testing.T) { t.Errorf("expected to have read 0 bytes, got %v", n) } rb.Write([]byte("4")) - n, err = ReadAtLeast(&rb, buf, 2) + n, err = ReadAtLeast(rb, buf, 2) if err != ErrUnexpectedEOF { t.Errorf("expected ErrUnexpectedEOF, got %v", err) } diff --git a/libgo/go/io/ioutil/ioutil.go b/libgo/go/io/ioutil/ioutil.go index 57d797e851c..5f1eecaabed 100644 --- a/libgo/go/io/ioutil/ioutil.go +++ b/libgo/go/io/ioutil/ioutil.go @@ -2,8 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Utility functions. - +// Package ioutil implements some I/O utility functions. package ioutil import ( @@ -102,3 +101,13 @@ func (nopCloser) Close() os.Error { return nil } func NopCloser(r io.Reader) io.ReadCloser { return nopCloser{r} } + +type devNull int + +func (devNull) Write(p []byte) (int, os.Error) { + return len(p), nil +} + +// Discard is an io.Writer on which all Write calls succeed +// without doing anything. +var Discard io.Writer = devNull(0) diff --git a/libgo/go/io/multi.go b/libgo/go/io/multi.go index 88e4f1b7698..d702d46c725 100644 --- a/libgo/go/io/multi.go +++ b/libgo/go/io/multi.go @@ -15,10 +15,8 @@ func (mr *multiReader) Read(p []byte) (n int, err os.Error) { n, err = mr.readers[0].Read(p) if n > 0 || err != os.EOF { if err == os.EOF { - // This shouldn't happen. - // Well-behaved Readers should never - // return non-zero bytes read with an - // EOF. But if so, we clean it. + // Don't return EOF yet. There may be more bytes + // in the remaining readers. err = nil } return diff --git a/libgo/go/json/decode.go b/libgo/go/json/decode.go index 501230c0c05..e78b60ccb54 100644 --- a/libgo/go/json/decode.go +++ b/libgo/go/json/decode.go @@ -87,7 +87,7 @@ func (e *UnmarshalTypeError) String() string { // led to an unexported (and therefore unwritable) struct field. type UnmarshalFieldError struct { Key string - Type *reflect.StructType + Type reflect.Type Field reflect.StructField } @@ -106,7 +106,7 @@ func (e *InvalidUnmarshalError) String() string { return "json: Unmarshal(nil)" } - if _, ok := e.Type.(*reflect.PtrType); !ok { + if e.Type.Kind() != reflect.Ptr { return "json: Unmarshal(non-pointer " + e.Type.String() + ")" } return "json: Unmarshal(nil " + e.Type.String() + ")" @@ -122,10 +122,10 @@ func (d *decodeState) unmarshal(v interface{}) (err os.Error) { } }() - rv := reflect.NewValue(v) - pv, ok := rv.(*reflect.PtrValue) - if !ok || pv.IsNil() { - return &InvalidUnmarshalError{reflect.Typeof(v)} + rv := reflect.ValueOf(v) + pv := rv + if pv.Kind() != reflect.Ptr || pv.IsNil() { + return &InvalidUnmarshalError{reflect.TypeOf(v)} } d.scan.reset() @@ -215,7 +215,7 @@ func (d *decodeState) scanWhile(op int) int { // value decodes a JSON value from d.data[d.off:] into the value. // it updates d.off to point past the decoded value. func (d *decodeState) value(v reflect.Value) { - if v == nil { + if !v.IsValid() { _, rest, err := nextValue(d.data[d.off:], &d.nextscan) if err != nil { d.error(err) @@ -262,20 +262,21 @@ func (d *decodeState) indirect(v reflect.Value, wantptr bool) (Unmarshaler, refl _, isUnmarshaler = v.Interface().(Unmarshaler) } - if iv, ok := v.(*reflect.InterfaceValue); ok && !iv.IsNil() { + if iv := v; iv.Kind() == reflect.Interface && !iv.IsNil() { v = iv.Elem() continue } - pv, ok := v.(*reflect.PtrValue) - if !ok { + + pv := v + if pv.Kind() != reflect.Ptr { break } - _, isptrptr := pv.Elem().(*reflect.PtrValue) - if !isptrptr && wantptr && !isUnmarshaler { + + if pv.Elem().Kind() != reflect.Ptr && wantptr && pv.CanSet() && !isUnmarshaler { return nil, pv } if pv.IsNil() { - pv.PointTo(reflect.MakeZero(pv.Type().(*reflect.PtrType).Elem())) + pv.Set(reflect.New(pv.Type().Elem())) } if isUnmarshaler { // Using v.Interface().(Unmarshaler) @@ -286,7 +287,7 @@ func (d *decodeState) indirect(v reflect.Value, wantptr bool) (Unmarshaler, refl // This is an unfortunate consequence of reflect. // An alternative would be to look up the // UnmarshalJSON method and return a FuncValue. - return v.Interface().(Unmarshaler), nil + return v.Interface().(Unmarshaler), reflect.Value{} } v = pv.Elem() } @@ -309,22 +310,23 @@ func (d *decodeState) array(v reflect.Value) { v = pv // Decoding into nil interface? Switch to non-reflect code. - iv, ok := v.(*reflect.InterfaceValue) + iv := v + ok := iv.Kind() == reflect.Interface if ok { - iv.Set(reflect.NewValue(d.arrayInterface())) + iv.Set(reflect.ValueOf(d.arrayInterface())) return } // Check type of target. - av, ok := v.(reflect.ArrayOrSliceValue) - if !ok { + av := v + if av.Kind() != reflect.Array && av.Kind() != reflect.Slice { d.saveError(&UnmarshalTypeError{"array", v.Type()}) d.off-- d.next() return } - sv, _ := v.(*reflect.SliceValue) + sv := v i := 0 for { @@ -339,26 +341,26 @@ func (d *decodeState) array(v reflect.Value) { d.scan.undo(op) // Get element of array, growing if necessary. - if i >= av.Cap() && sv != nil { + if i >= av.Cap() && sv.IsValid() { newcap := sv.Cap() + sv.Cap()/2 if newcap < 4 { newcap = 4 } - newv := reflect.MakeSlice(sv.Type().(*reflect.SliceType), sv.Len(), newcap) + newv := reflect.MakeSlice(sv.Type(), sv.Len(), newcap) reflect.Copy(newv, sv) sv.Set(newv) } - if i >= av.Len() && sv != nil { + if i >= av.Len() && sv.IsValid() { // Must be slice; gave up on array during i >= av.Cap(). sv.SetLen(i + 1) } // Decode into element. if i < av.Len() { - d.value(av.Elem(i)) + d.value(av.Index(i)) } else { // Ran out of fixed array: skip. - d.value(nil) + d.value(reflect.Value{}) } i++ @@ -372,11 +374,11 @@ func (d *decodeState) array(v reflect.Value) { } } if i < av.Len() { - if sv == nil { + if !sv.IsValid() { // Array. Zero the rest. - z := reflect.MakeZero(av.Type().(*reflect.ArrayType).Elem()) + z := reflect.Zero(av.Type().Elem()) for ; i < av.Len(); i++ { - av.Elem(i).SetValue(z) + av.Index(i).Set(z) } } else { sv.SetLen(i) @@ -405,41 +407,43 @@ func (d *decodeState) object(v reflect.Value) { v = pv // Decoding into nil interface? Switch to non-reflect code. - iv, ok := v.(*reflect.InterfaceValue) - if ok { - iv.Set(reflect.NewValue(d.objectInterface())) + iv := v + if iv.Kind() == reflect.Interface { + iv.Set(reflect.ValueOf(d.objectInterface())) return } // Check type of target: struct or map[string]T var ( - mv *reflect.MapValue - sv *reflect.StructValue + mv reflect.Value + sv reflect.Value ) - switch v := v.(type) { - case *reflect.MapValue: + switch v.Kind() { + case reflect.Map: // map must have string type - t := v.Type().(*reflect.MapType) - if t.Key() != reflect.Typeof("") { + t := v.Type() + if t.Key() != reflect.TypeOf("") { d.saveError(&UnmarshalTypeError{"object", v.Type()}) break } mv = v if mv.IsNil() { - mv.SetValue(reflect.MakeMap(t)) + mv.Set(reflect.MakeMap(t)) } - case *reflect.StructValue: + case reflect.Struct: sv = v default: d.saveError(&UnmarshalTypeError{"object", v.Type()}) } - if mv == nil && sv == nil { + if !mv.IsValid() && !sv.IsValid() { d.off-- d.next() // skip over { } in input return } + var mapElem reflect.Value + for { // Read opening " of string key or closing }. op := d.scanWhile(scanSkipSpace) @@ -462,12 +466,18 @@ func (d *decodeState) object(v reflect.Value) { // Figure out field corresponding to key. var subv reflect.Value - if mv != nil { - subv = reflect.MakeZero(mv.Type().(*reflect.MapType).Elem()) + if mv.IsValid() { + elemType := mv.Type().Elem() + if !mapElem.IsValid() { + mapElem = reflect.New(elemType).Elem() + } else { + mapElem.Set(reflect.Zero(elemType)) + } + subv = mapElem } else { var f reflect.StructField var ok bool - st := sv.Type().(*reflect.StructType) + st := sv.Type() // First try for field with that tag. if isValidTag(key) { for i := 0; i < sv.NumField(); i++ { @@ -510,8 +520,8 @@ func (d *decodeState) object(v reflect.Value) { // Write value back to map; // if using struct, subv points into struct already. - if mv != nil { - mv.SetElem(reflect.NewValue(key), subv) + if mv.IsValid() { + mv.SetMapIndex(reflect.ValueOf(key), subv) } // Next token must be , or }. @@ -552,22 +562,22 @@ func (d *decodeState) literal(v reflect.Value) { switch c := item[0]; c { case 'n': // null - switch v.(type) { + switch v.Kind() { default: d.saveError(&UnmarshalTypeError{"null", v.Type()}) - case *reflect.InterfaceValue, *reflect.PtrValue, *reflect.MapValue: - v.SetValue(nil) + case reflect.Interface, reflect.Ptr, reflect.Map: + v.Set(reflect.Zero(v.Type())) } case 't', 'f': // true, false value := c == 't' - switch v := v.(type) { + switch v.Kind() { default: d.saveError(&UnmarshalTypeError{"bool", v.Type()}) - case *reflect.BoolValue: - v.Set(value) - case *reflect.InterfaceValue: - v.Set(reflect.NewValue(value)) + case reflect.Bool: + v.SetBool(value) + case reflect.Interface: + v.Set(reflect.ValueOf(value)) } case '"': // string @@ -575,10 +585,10 @@ func (d *decodeState) literal(v reflect.Value) { if !ok { d.error(errPhase) } - switch v := v.(type) { + switch v.Kind() { default: d.saveError(&UnmarshalTypeError{"string", v.Type()}) - case *reflect.SliceValue: + case reflect.Slice: if v.Type() != byteSliceType { d.saveError(&UnmarshalTypeError{"string", v.Type()}) break @@ -589,11 +599,11 @@ func (d *decodeState) literal(v reflect.Value) { d.saveError(err) break } - v.Set(reflect.NewValue(b[0:n]).(*reflect.SliceValue)) - case *reflect.StringValue: - v.Set(string(s)) - case *reflect.InterfaceValue: - v.Set(reflect.NewValue(string(s))) + v.Set(reflect.ValueOf(b[0:n])) + case reflect.String: + v.SetString(string(s)) + case reflect.Interface: + v.Set(reflect.ValueOf(string(s))) } default: // number @@ -601,40 +611,40 @@ func (d *decodeState) literal(v reflect.Value) { d.error(errPhase) } s := string(item) - switch v := v.(type) { + switch v.Kind() { default: d.error(&UnmarshalTypeError{"number", v.Type()}) - case *reflect.InterfaceValue: + case reflect.Interface: n, err := strconv.Atof64(s) if err != nil { d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) break } - v.Set(reflect.NewValue(n)) + v.Set(reflect.ValueOf(n)) - case *reflect.IntValue: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: n, err := strconv.Atoi64(s) - if err != nil || v.Overflow(n) { + if err != nil || v.OverflowInt(n) { d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) break } - v.Set(n) + v.SetInt(n) - case *reflect.UintValue: + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: n, err := strconv.Atoui64(s) - if err != nil || v.Overflow(n) { + if err != nil || v.OverflowUint(n) { d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) break } - v.Set(n) + v.SetUint(n) - case *reflect.FloatValue: + case reflect.Float32, reflect.Float64: n, err := strconv.AtofN(s, v.Type().Bits()) - if err != nil || v.Overflow(n) { + if err != nil || v.OverflowFloat(n) { d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) break } - v.Set(n) + v.SetFloat(n) } } } @@ -764,7 +774,7 @@ func (d *decodeState) literalInterface() interface{} { } n, err := strconv.Atof64(string(item)) if err != nil { - d.saveError(&UnmarshalTypeError{"number " + string(item), reflect.Typeof(0.0)}) + d.saveError(&UnmarshalTypeError{"number " + string(item), reflect.TypeOf(0.0)}) } return n } diff --git a/libgo/go/json/decode_test.go b/libgo/go/json/decode_test.go index aad8b635f2b..bf8bf10bf89 100644 --- a/libgo/go/json/decode_test.go +++ b/libgo/go/json/decode_test.go @@ -21,7 +21,7 @@ type tx struct { x int } -var txType = reflect.Typeof((*tx)(nil)).(*reflect.PtrType).Elem().(*reflect.StructType) +var txType = reflect.TypeOf((*tx)(nil)).Elem() // A type that can unmarshal itself. @@ -64,14 +64,14 @@ var unmarshalTests = []unmarshalTest{ {`"g-clef: \uD834\uDD1E"`, new(string), "g-clef: \U0001D11E", nil}, {`"invalid: \uD834x\uDD1E"`, new(string), "invalid: \uFFFDx\uFFFD", nil}, {"null", new(interface{}), nil, nil}, - {`{"X": [1,2,3], "Y": 4}`, new(T), T{Y: 4}, &UnmarshalTypeError{"array", reflect.Typeof("")}}, + {`{"X": [1,2,3], "Y": 4}`, new(T), T{Y: 4}, &UnmarshalTypeError{"array", reflect.TypeOf("")}}, {`{"x": 1}`, new(tx), tx{}, &UnmarshalFieldError{"x", txType, txType.Field(0)}}, // skip invalid tags {`{"X":"a", "y":"b", "Z":"c"}`, new(badTag), badTag{"a", "b", "c"}, nil}, // syntax errors - {`{"X": "foo", "Y"}`, nil, nil, SyntaxError("invalid character '}' after object key")}, + {`{"X": "foo", "Y"}`, nil, nil, &SyntaxError{"invalid character '}' after object key", 17}}, // composite tests {allValueIndent, new(All), allValue, nil}, @@ -125,12 +125,12 @@ func TestMarshalBadUTF8(t *testing.T) { } func TestUnmarshal(t *testing.T) { - var scan scanner for i, tt := range unmarshalTests { + var scan scanner in := []byte(tt.in) if err := checkValid(in, &scan); err != nil { if !reflect.DeepEqual(err, tt.err) { - t.Errorf("#%d: checkValid: %v", i, err) + t.Errorf("#%d: checkValid: %#v", i, err) continue } } @@ -138,8 +138,7 @@ func TestUnmarshal(t *testing.T) { continue } // v = new(right-type) - v := reflect.NewValue(tt.ptr).(*reflect.PtrValue) - v.PointTo(reflect.MakeZero(v.Type().(*reflect.PtrType).Elem())) + v := reflect.New(reflect.TypeOf(tt.ptr).Elem()) if err := Unmarshal([]byte(in), v.Interface()); !reflect.DeepEqual(err, tt.err) { t.Errorf("#%d: %v want %v", i, err, tt.err) continue diff --git a/libgo/go/json/encode.go b/libgo/go/json/encode.go index 26ce47039f6..ec0a14a6a4d 100644 --- a/libgo/go/json/encode.go +++ b/libgo/go/json/encode.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The json package implements encoding and decoding of JSON objects as -// defined in RFC 4627. +// Package json implements encoding and decoding of JSON objects as defined in +// RFC 4627. package json import ( @@ -172,7 +172,7 @@ func (e *encodeState) marshal(v interface{}) (err os.Error) { err = r.(os.Error) } }() - e.reflectValue(reflect.NewValue(v)) + e.reflectValue(reflect.ValueOf(v)) return nil } @@ -180,10 +180,10 @@ func (e *encodeState) error(err os.Error) { panic(err) } -var byteSliceType = reflect.Typeof([]byte(nil)) +var byteSliceType = reflect.TypeOf([]byte(nil)) func (e *encodeState) reflectValue(v reflect.Value) { - if v == nil { + if !v.IsValid() { e.WriteString("null") return } @@ -200,30 +200,30 @@ func (e *encodeState) reflectValue(v reflect.Value) { return } - switch v := v.(type) { - case *reflect.BoolValue: - x := v.Get() + switch v.Kind() { + case reflect.Bool: + x := v.Bool() if x { e.WriteString("true") } else { e.WriteString("false") } - case *reflect.IntValue: - e.WriteString(strconv.Itoa64(v.Get())) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + e.WriteString(strconv.Itoa64(v.Int())) - case *reflect.UintValue: - e.WriteString(strconv.Uitoa64(v.Get())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + e.WriteString(strconv.Uitoa64(v.Uint())) - case *reflect.FloatValue: - e.WriteString(strconv.FtoaN(v.Get(), 'g', -1, v.Type().Bits())) + case reflect.Float32, reflect.Float64: + e.WriteString(strconv.FtoaN(v.Float(), 'g', -1, v.Type().Bits())) - case *reflect.StringValue: - e.string(v.Get()) + case reflect.String: + e.string(v.String()) - case *reflect.StructValue: + case reflect.Struct: e.WriteByte('{') - t := v.Type().(*reflect.StructType) + t := v.Type() n := v.NumField() first := true for i := 0; i < n; i++ { @@ -246,8 +246,8 @@ func (e *encodeState) reflectValue(v reflect.Value) { } e.WriteByte('}') - case *reflect.MapValue: - if _, ok := v.Type().(*reflect.MapType).Key().(*reflect.StringType); !ok { + case reflect.Map: + if v.Type().Key().Kind() != reflect.String { e.error(&UnsupportedTypeError{v.Type()}) } if v.IsNil() { @@ -255,19 +255,19 @@ func (e *encodeState) reflectValue(v reflect.Value) { break } e.WriteByte('{') - var sv stringValues = v.Keys() + var sv stringValues = v.MapKeys() sort.Sort(sv) for i, k := range sv { if i > 0 { e.WriteByte(',') } - e.string(k.(*reflect.StringValue).Get()) + e.string(k.String()) e.WriteByte(':') - e.reflectValue(v.Elem(k)) + e.reflectValue(v.MapIndex(k)) } e.WriteByte('}') - case reflect.ArrayOrSliceValue: + case reflect.Array, reflect.Slice: if v.Type() == byteSliceType { e.WriteByte('"') s := v.Interface().([]byte) @@ -292,11 +292,11 @@ func (e *encodeState) reflectValue(v reflect.Value) { if i > 0 { e.WriteByte(',') } - e.reflectValue(v.Elem(i)) + e.reflectValue(v.Index(i)) } e.WriteByte(']') - case interfaceOrPtrValue: + case reflect.Interface, reflect.Ptr: if v.IsNil() { e.WriteString("null") return @@ -328,7 +328,7 @@ type stringValues []reflect.Value func (sv stringValues) Len() int { return len(sv) } func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] } func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) } -func (sv stringValues) get(i int) string { return sv[i].(*reflect.StringValue).Get() } +func (sv stringValues) get(i int) string { return sv[i].String() } func (e *encodeState) string(s string) { e.WriteByte('"') diff --git a/libgo/go/json/scanner.go b/libgo/go/json/scanner.go index e98ddef5cc1..49c2edd5453 100644 --- a/libgo/go/json/scanner.go +++ b/libgo/go/json/scanner.go @@ -23,6 +23,7 @@ import ( func checkValid(data []byte, scan *scanner) os.Error { scan.reset() for _, c := range data { + scan.bytes++ if scan.step(scan, int(c)) == scanError { return scan.err } @@ -56,10 +57,12 @@ func nextValue(data []byte, scan *scanner) (value, rest []byte, err os.Error) { } // A SyntaxError is a description of a JSON syntax error. -type SyntaxError string - -func (e SyntaxError) String() string { return string(e) } +type SyntaxError struct { + msg string // description of error + Offset int64 // error occurred after reading Offset bytes +} +func (e *SyntaxError) String() string { return e.msg } // A scanner is a JSON scanning state machine. // Callers call scan.reset() and then pass bytes in one at a time @@ -89,6 +92,9 @@ type scanner struct { // 1-byte redo (see undo method) redoCode int redoState func(*scanner, int) int + + // total bytes consumed, updated by decoder.Decode + bytes int64 } // These values are returned by the state transition functions @@ -148,7 +154,7 @@ func (s *scanner) eof() int { return scanEnd } if s.err == nil { - s.err = SyntaxError("unexpected end of JSON input") + s.err = &SyntaxError{"unexpected end of JSON input", s.bytes} } return scanError } @@ -581,7 +587,7 @@ func stateError(s *scanner, c int) int { // error records an error and switches to the error state. func (s *scanner) error(c int, context string) int { s.step = stateError - s.err = SyntaxError("invalid character " + quoteChar(c) + " " + context) + s.err = &SyntaxError{"invalid character " + quoteChar(c) + " " + context, s.bytes} return scanError } diff --git a/libgo/go/json/stream.go b/libgo/go/json/stream.go index cb9b16559ed..f143b3f0ade 100644 --- a/libgo/go/json/stream.go +++ b/libgo/go/json/stream.go @@ -23,8 +23,8 @@ func NewDecoder(r io.Reader) *Decoder { return &Decoder{r: r} } -// Decode reads the next JSON-encoded value from the -// connection and stores it in the value pointed to by v. +// Decode reads the next JSON-encoded value from its +// input and stores it in the value pointed to by v. // // See the documentation for Unmarshal for details about // the conversion of JSON into a Go value. @@ -62,6 +62,7 @@ Input: for { // Look in the buffer for a new value. for i, c := range dec.buf[scanp:] { + dec.scan.bytes++ v := dec.scan.step(&dec.scan, int(c)) if v == scanEnd { scanp += i diff --git a/libgo/go/log/log.go b/libgo/go/log/log.go index 33140ee08af..00bce6a17dc 100644 --- a/libgo/go/log/log.go +++ b/libgo/go/log/log.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Simple logging package. It defines a type, Logger, with methods -// for formatting output. It also has a predefined 'standard' Logger -// accessible through helper functions Print[f|ln], Fatal[f|ln], and +// Package log implements a simple logging package. It defines a type, Logger, +// with methods for formatting output. It also has a predefined 'standard' +// Logger accessible through helper functions Print[f|ln], Fatal[f|ln], and // Panic[f|ln], which are easier to use than creating a Logger manually. // That logger writes to standard error and prints the date and time // of each logged message. diff --git a/libgo/go/math/const.go b/libgo/go/math/const.go index b53527a4f39..a108d3e294d 100644 --- a/libgo/go/math/const.go +++ b/libgo/go/math/const.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The math package provides basic constants and mathematical functions. +// Package math provides basic constants and mathematical functions. package math // Mathematical constants. diff --git a/libgo/go/mime/mediatype.go b/libgo/go/mime/mediatype.go index eb629aa6f7f..f28ff3e9681 100644 --- a/libgo/go/mime/mediatype.go +++ b/libgo/go/mime/mediatype.go @@ -6,10 +6,30 @@ package mime import ( "bytes" + "fmt" + "os" "strings" "unicode" ) +func validMediaTypeOrDisposition(s string) bool { + typ, rest := consumeToken(s) + if typ == "" { + return false + } + if rest == "" { + return true + } + if !strings.HasPrefix(rest, "/") { + return false + } + subtype, rest := consumeToken(rest[1:]) + if subtype == "" { + return false + } + return rest == "" +} + // ParseMediaType parses a media type value and any optional // parameters, per RFC 1531. Media types are the values in // Content-Type and Content-Disposition headers (RFC 2183). On @@ -22,25 +42,112 @@ func ParseMediaType(v string) (mediatype string, params map[string]string) { i = len(v) } mediatype = strings.TrimSpace(strings.ToLower(v[0:i])) + if !validMediaTypeOrDisposition(mediatype) { + return "", nil + } + params = make(map[string]string) + // Map of base parameter name -> parameter name -> value + // for parameters containing a '*' character. + // Lazily initialized. + var continuation map[string]map[string]string + v = v[i:] for len(v) > 0 { v = strings.TrimLeftFunc(v, unicode.IsSpace) if len(v) == 0 { - return + break } key, value, rest := consumeMediaParam(v) if key == "" { + if strings.TrimSpace(rest) == ";" { + // Ignore trailing semicolons. + // Not an error. + return + } // Parse error. return "", nil } - params[key] = value + + pmap := params + if idx := strings.Index(key, "*"); idx != -1 { + baseName := key[:idx] + if continuation == nil { + continuation = make(map[string]map[string]string) + } + var ok bool + if pmap, ok = continuation[baseName]; !ok { + continuation[baseName] = make(map[string]string) + pmap = continuation[baseName] + } + } + if _, exists := pmap[key]; exists { + // Duplicate parameter name is bogus. + return "", nil + } + pmap[key] = value v = rest } + + // Stitch together any continuations or things with stars + // (i.e. RFC 2231 things with stars: "foo*0" or "foo*") + var buf bytes.Buffer + for key, pieceMap := range continuation { + singlePartKey := key + "*" + if v, ok := pieceMap[singlePartKey]; ok { + decv := decode2231Enc(v) + params[key] = decv + continue + } + + buf.Reset() + valid := false + for n := 0; ; n++ { + simplePart := fmt.Sprintf("%s*%d", key, n) + if v, ok := pieceMap[simplePart]; ok { + valid = true + buf.WriteString(v) + continue + } + encodedPart := simplePart + "*" + if v, ok := pieceMap[encodedPart]; ok { + valid = true + if n == 0 { + buf.WriteString(decode2231Enc(v)) + } else { + decv, _ := percentHexUnescape(v) + buf.WriteString(decv) + } + } else { + break + } + } + if valid { + params[key] = buf.String() + } + } + return } +func decode2231Enc(v string) string { + sv := strings.Split(v, "'", 3) + if len(sv) != 3 { + return "" + } + // TODO: ignoring lang in sv[1] for now. If anybody needs it we'll + // need to decide how to expose it in the API. But I'm not sure + // anybody uses it in practice. + charset := strings.ToLower(sv[0]) + if charset != "us-ascii" && charset != "utf-8" { + // TODO: unsupported encoding + return "" + } + encv, _ := percentHexUnescape(sv[2]) + return encv +} + func isNotTokenChar(rune int) bool { return !IsTokenChar(rune) } @@ -66,10 +173,12 @@ func consumeToken(v string) (token, rest string) { // quoted-string) and the rest of the string. On failure, returns // ("", v). func consumeValue(v string) (value, rest string) { - if !strings.HasPrefix(v, `"`) { + if !strings.HasPrefix(v, `"`) && !strings.HasPrefix(v, `'`) { return consumeToken(v) } + leadQuote := int(v[0]) + // parse a quoted-string rest = v[1:] // consume the leading quote buffer := new(bytes.Buffer) @@ -78,17 +187,14 @@ func consumeValue(v string) (value, rest string) { for idx, rune = range rest { switch { case nextIsLiteral: - if rune >= 0x80 { - return "", v - } buffer.WriteRune(rune) nextIsLiteral = false - case rune == '"': + case rune == leadQuote: return buffer.String(), rest[idx+1:] - case IsQText(rune): - buffer.WriteRune(rune) case rune == '\\': nextIsLiteral = true + case rune != '\r' && rune != '\n': + buffer.WriteRune(rune) default: return "", v } @@ -108,13 +214,79 @@ func consumeMediaParam(v string) (param, value, rest string) { if param == "" { return "", "", v } + + rest = strings.TrimLeftFunc(rest, unicode.IsSpace) if !strings.HasPrefix(rest, "=") { return "", "", v } rest = rest[1:] // consume equals sign + rest = strings.TrimLeftFunc(rest, unicode.IsSpace) value, rest = consumeValue(rest) if value == "" { return "", "", v } return param, value, rest } + +func percentHexUnescape(s string) (string, os.Error) { + // Count %, check that they're well-formed. + percents := 0 + for i := 0; i < len(s); { + if s[i] != '%' { + i++ + continue + } + percents++ + if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) { + s = s[i:] + if len(s) > 3 { + s = s[0:3] + } + return "", fmt.Errorf("mime: bogus characters after %%: %q", s) + } + i += 3 + } + if percents == 0 { + return s, nil + } + + t := make([]byte, len(s)-2*percents) + j := 0 + for i := 0; i < len(s); { + switch s[i] { + case '%': + t[j] = unhex(s[i+1])<<4 | unhex(s[i+2]) + j++ + i += 3 + default: + t[j] = s[i] + j++ + i++ + } + } + return string(t), nil +} + +func ishex(c byte) bool { + switch { + case '0' <= c && c <= '9': + return true + case 'a' <= c && c <= 'f': + return true + case 'A' <= c && c <= 'F': + return true + } + return false +} + +func unhex(c byte) byte { + switch { + case '0' <= c && c <= '9': + return c - '0' + case 'a' <= c && c <= 'f': + return c - 'a' + 10 + case 'A' <= c && c <= 'F': + return c - 'A' + 10 + } + return 0 +} diff --git a/libgo/go/mime/mediatype_test.go b/libgo/go/mime/mediatype_test.go index 4891e899d4c..454ddd03778 100644 --- a/libgo/go/mime/mediatype_test.go +++ b/libgo/go/mime/mediatype_test.go @@ -5,6 +5,7 @@ package mime import ( + "reflect" "testing" ) @@ -85,23 +86,152 @@ func TestConsumeMediaParam(t *testing.T) { } } +type mediaTypeTest struct { + in string + t string + p map[string]string +} + func TestParseMediaType(t *testing.T) { - tests := [...]string{ - `form-data; name="foo"`, - ` form-data ; name=foo`, - `FORM-DATA;name="foo"`, - ` FORM-DATA ; name="foo"`, - ` FORM-DATA ; name="foo"`, - `form-data; key=value; blah="value";name="foo" `, + // Convenience map initializer + m := func(s ...string) map[string]string { + sm := make(map[string]string) + for i := 0; i < len(s); i += 2 { + sm[s[i]] = s[i+1] + } + return sm + } + + nameFoo := map[string]string{"name": "foo"} + tests := []mediaTypeTest{ + {`form-data; name="foo"`, "form-data", nameFoo}, + {` form-data ; name=foo`, "form-data", nameFoo}, + {`FORM-DATA;name="foo"`, "form-data", nameFoo}, + {` FORM-DATA ; name="foo"`, "form-data", nameFoo}, + {` FORM-DATA ; name="foo"`, "form-data", nameFoo}, + + {`form-data; key=value; blah="value";name="foo" `, + "form-data", + m("key", "value", "blah", "value", "name", "foo")}, + + {`foo; key=val1; key=the-key-appears-again-which-is-bogus`, + "", m()}, + + // From RFC 2231: + {`application/x-stuff; title*=us-ascii'en-us'This%20is%20%2A%2A%2Afun%2A%2A%2A`, + "application/x-stuff", + m("title", "This is ***fun***")}, + + {`message/external-body; access-type=URL; ` + + `URL*0="ftp://";` + + `URL*1="cs.utk.edu/pub/moore/bulk-mailer/bulk-mailer.tar"`, + "message/external-body", + m("access-type", "URL", + "URL", "ftp://cs.utk.edu/pub/moore/bulk-mailer/bulk-mailer.tar")}, + + {`application/x-stuff; ` + + `title*0*=us-ascii'en'This%20is%20even%20more%20; ` + + `title*1*=%2A%2A%2Afun%2A%2A%2A%20; ` + + `title*2="isn't it!"`, + "application/x-stuff", + m("title", "This is even more ***fun*** isn't it!")}, + + // Tests from http://greenbytes.de/tech/tc2231/ + // TODO(bradfitz): add the rest of the tests from that site. + {`attachment; filename="f\oo.html"`, + "attachment", + m("filename", "foo.html")}, + {`attachment; filename="\"quoting\" tested.html"`, + "attachment", + m("filename", `"quoting" tested.html`)}, + {`attachment; filename="Here's a semicolon;.html"`, + "attachment", + m("filename", "Here's a semicolon;.html")}, + {`attachment; foo="\"\\";filename="foo.html"`, + "attachment", + m("foo", "\"\\", "filename", "foo.html")}, + {`attachment; filename=foo.html`, + "attachment", + m("filename", "foo.html")}, + {`attachment; filename=foo.html ;`, + "attachment", + m("filename", "foo.html")}, + {`attachment; filename='foo.html'`, + "attachment", + m("filename", "foo.html")}, + {`attachment; filename="foo-%41.html"`, + "attachment", + m("filename", "foo-%41.html")}, + {`attachment; filename="foo-%\41.html"`, + "attachment", + m("filename", "foo-%41.html")}, + {`filename=foo.html`, + "", m()}, + {`x=y; filename=foo.html`, + "", m()}, + {`"foo; filename=bar;baz"; filename=qux`, + "", m()}, + {`inline; attachment; filename=foo.html`, + "", m()}, + {`attachment; filename="foo.html".txt`, + "", m()}, + {`attachment; filename="bar`, + "", m()}, + {`attachment; creation-date="Wed, 12 Feb 1997 16:29:51 -0500"`, + "attachment", + m("creation-date", "Wed, 12 Feb 1997 16:29:51 -0500")}, + {`foobar`, "foobar", m()}, + {`attachment; filename* =UTF-8''foo-%c3%a4.html`, + "attachment", + m("filename", "foo-ä.html")}, + {`attachment; filename*=UTF-8''A-%2541.html`, + "attachment", + m("filename", "A-%41.html")}, + {`attachment; filename*0="foo."; filename*1="html"`, + "attachment", + m("filename", "foo.html")}, + {`attachment; filename*0*=UTF-8''foo-%c3%a4; filename*1=".html"`, + "attachment", + m("filename", "foo-ä.html")}, + {`attachment; filename*0="foo"; filename*01="bar"`, + "attachment", + m("filename", "foo")}, + {`attachment; filename*0="foo"; filename*2="bar"`, + "attachment", + m("filename", "foo")}, + {`attachment; filename*1="foo"; filename*2="bar"`, + "attachment", m()}, + {`attachment; filename*1="bar"; filename*0="foo"`, + "attachment", + m("filename", "foobar")}, + {`attachment; filename="foo-ae.html"; filename*=UTF-8''foo-%c3%a4.html`, + "attachment", + m("filename", "foo-ä.html")}, + {`attachment; filename*=UTF-8''foo-%c3%a4.html; filename="foo-ae.html"`, + "attachment", + m("filename", "foo-ä.html")}, + + // Browsers also just send UTF-8 directly without RFC 2231, + // at least when the source page is served with UTF-8. + {`form-data; firstname="Брэд"; lastname="Фицпатрик"`, + "form-data", + m("firstname", "Брэд", "lastname", "Фицпатрик")}, } for _, test := range tests { - mt, params := ParseMediaType(test) - if mt != "form-data" { - t.Errorf("expected type form-data for %s, got [%s]", test, mt) + mt, params := ParseMediaType(test.in) + if g, e := mt, test.t; g != e { + t.Errorf("for input %q, expected type %q, got %q", + test.in, e, g) + continue + } + if len(params) == 0 && len(test.p) == 0 { continue } - if params["name"] != "foo" { - t.Errorf("expected name=foo for %s", test) + if !reflect.DeepEqual(params, test.p) { + t.Errorf("for input %q, wrong params.\n"+ + "expected: %#v\n"+ + " got: %#v", + test.in, test.p, params) } } } diff --git a/libgo/go/mime/multipart/formdata.go b/libgo/go/mime/multipart/formdata.go new file mode 100644 index 00000000000..5f328656590 --- /dev/null +++ b/libgo/go/mime/multipart/formdata.go @@ -0,0 +1,166 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package multipart + +import ( + "bytes" + "io" + "io/ioutil" + "net/textproto" + "os" +) + +// TODO(adg,bradfitz): find a way to unify the DoS-prevention strategy here +// with that of the http package's ParseForm. + +// ReadForm parses an entire multipart message whose parts have +// a Content-Disposition of "form-data". +// It stores up to maxMemory bytes of the file parts in memory +// and the remainder on disk in temporary files. +func (r *multiReader) ReadForm(maxMemory int64) (f *Form, err os.Error) { + form := &Form{make(map[string][]string), make(map[string][]*FileHeader)} + defer func() { + if err != nil { + form.RemoveAll() + } + }() + + maxValueBytes := int64(10 << 20) // 10 MB is a lot of text. + for { + p, err := r.NextPart() + if err == os.EOF { + break + } + if err != nil { + return nil, err + } + + name := p.FormName() + if name == "" { + continue + } + filename := p.FileName() + + var b bytes.Buffer + + if filename == "" { + // value, store as string in memory + n, err := io.Copyn(&b, p, maxValueBytes) + if err != nil && err != os.EOF { + return nil, err + } + maxValueBytes -= n + if maxValueBytes == 0 { + return nil, os.NewError("multipart: message too large") + } + form.Value[name] = append(form.Value[name], b.String()) + continue + } + + // file, store in memory or on disk + fh := &FileHeader{ + Filename: filename, + Header: p.Header, + } + n, err := io.Copyn(&b, p, maxMemory+1) + if err != nil && err != os.EOF { + return nil, err + } + if n > maxMemory { + // too big, write to disk and flush buffer + file, err := ioutil.TempFile("", "multipart-") + if err != nil { + return nil, err + } + defer file.Close() + _, err = io.Copy(file, io.MultiReader(&b, p)) + if err != nil { + os.Remove(file.Name()) + return nil, err + } + fh.tmpfile = file.Name() + } else { + fh.content = b.Bytes() + maxMemory -= n + } + form.File[name] = append(form.File[name], fh) + } + + return form, nil +} + +// Form is a parsed multipart form. +// Its File parts are stored either in memory or on disk, +// and are accessible via the *FileHeader's Open method. +// Its Value parts are stored as strings. +// Both are keyed by field name. +type Form struct { + Value map[string][]string + File map[string][]*FileHeader +} + +// RemoveAll removes any temporary files associated with a Form. +func (f *Form) RemoveAll() os.Error { + var err os.Error + for _, fhs := range f.File { + for _, fh := range fhs { + if fh.tmpfile != "" { + e := os.Remove(fh.tmpfile) + if e != nil && err == nil { + err = e + } + } + } + } + return err +} + +// A FileHeader describes a file part of a multipart request. +type FileHeader struct { + Filename string + Header textproto.MIMEHeader + + content []byte + tmpfile string +} + +// Open opens and returns the FileHeader's associated File. +func (fh *FileHeader) Open() (File, os.Error) { + if b := fh.content; b != nil { + r := io.NewSectionReader(sliceReaderAt(b), 0, int64(len(b))) + return sectionReadCloser{r}, nil + } + return os.Open(fh.tmpfile) +} + +// File is an interface to access the file part of a multipart message. +// Its contents may be either stored in memory or on disk. +// If stored on disk, the File's underlying concrete type will be an *os.File. +type File interface { + io.Reader + io.ReaderAt + io.Seeker + io.Closer +} + +// helper types to turn a []byte into a File + +type sectionReadCloser struct { + *io.SectionReader +} + +func (rc sectionReadCloser) Close() os.Error { + return nil +} + +type sliceReaderAt []byte + +func (r sliceReaderAt) ReadAt(b []byte, off int64) (int, os.Error) { + if int(off) >= len(r) || off < 0 { + return 0, os.EINVAL + } + n := copy(b, r[int(off):]) + return n, nil +} diff --git a/libgo/go/mime/multipart/formdata_test.go b/libgo/go/mime/multipart/formdata_test.go new file mode 100644 index 00000000000..b56e2a430e0 --- /dev/null +++ b/libgo/go/mime/multipart/formdata_test.go @@ -0,0 +1,87 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package multipart + +import ( + "bytes" + "io" + "os" + "regexp" + "testing" +) + +func TestReadForm(t *testing.T) { + testBody := regexp.MustCompile("\n").ReplaceAllString(message, "\r\n") + b := bytes.NewBufferString(testBody) + r := NewReader(b, boundary) + f, err := r.ReadForm(25) + if err != nil { + t.Fatal("ReadForm:", err) + } + defer f.RemoveAll() + if g, e := f.Value["texta"][0], textaValue; g != e { + t.Errorf("texta value = %q, want %q", g, e) + } + if g, e := f.Value["textb"][0], textbValue; g != e { + t.Errorf("texta value = %q, want %q", g, e) + } + fd := testFile(t, f.File["filea"][0], "filea.txt", fileaContents) + if _, ok := fd.(*os.File); ok { + t.Error("file is *os.File, should not be") + } + fd = testFile(t, f.File["fileb"][0], "fileb.txt", filebContents) + if _, ok := fd.(*os.File); !ok { + t.Error("file has unexpected underlying type %T", fd) + } +} + +func testFile(t *testing.T, fh *FileHeader, efn, econtent string) File { + if fh.Filename != efn { + t.Errorf("filename = %q, want %q", fh.Filename, efn) + } + f, err := fh.Open() + if err != nil { + t.Fatal("opening file:", err) + } + b := new(bytes.Buffer) + _, err = io.Copy(b, f) + if err != nil { + t.Fatal("copying contents:", err) + } + if g := b.String(); g != econtent { + t.Errorf("contents = %q, want %q", g, econtent) + } + return f +} + +const ( + fileaContents = "This is a test file." + filebContents = "Another test file." + textaValue = "foo" + textbValue = "bar" + boundary = `MyBoundary` +) + +const message = ` +--MyBoundary +Content-Disposition: form-data; name="filea"; filename="filea.txt" +Content-Type: text/plain + +` + fileaContents + ` +--MyBoundary +Content-Disposition: form-data; name="fileb"; filename="fileb.txt" +Content-Type: text/plain + +` + filebContents + ` +--MyBoundary +Content-Disposition: form-data; name="texta" + +` + textaValue + ` +--MyBoundary +Content-Disposition: form-data; name="textb" + +` + textbValue + ` +--MyBoundary-- +` diff --git a/libgo/go/mime/multipart/multipart.go b/libgo/go/mime/multipart/multipart.go index 0a65a447db9..9affa112611 100644 --- a/libgo/go/mime/multipart/multipart.go +++ b/libgo/go/mime/multipart/multipart.go @@ -15,25 +15,32 @@ package multipart import ( "bufio" "bytes" + "fmt" "io" + "io/ioutil" "mime" "net/textproto" "os" "regexp" - "strings" ) var headerRegexp *regexp.Regexp = regexp.MustCompile("^([a-zA-Z0-9\\-]+): *([^\r\n]+)") +var emptyParams = make(map[string]string) + // Reader is an iterator over parts in a MIME multipart body. // Reader's underlying parser consumes its input as needed. Seeking // isn't supported. type Reader interface { - // NextPart returns the next part in the multipart, or (nil, - // nil) on EOF. An error is returned if the underlying reader - // reports errors, or on truncated or otherwise malformed - // input. + // NextPart returns the next part in the multipart or an error. + // When there are no more parts, the error os.EOF is returned. NextPart() (*Part, os.Error) + + // ReadForm parses an entire multipart message whose parts have + // a Content-Disposition of "form-data". + // It stores up to maxMemory bytes of the file parts in memory + // and the remainder on disk in temporary files. + ReadForm(maxMemory int64) (*Form, os.Error) } // A Part represents a single part in a multipart body. @@ -45,6 +52,9 @@ type Part struct { buffer *bytes.Buffer mr *multiReader + + disposition string + dispositionParams map[string]string } // FormName returns the name parameter if p has a Content-Disposition @@ -52,55 +62,67 @@ type Part struct { func (p *Part) FormName() string { // See http://tools.ietf.org/html/rfc2183 section 2 for EBNF // of Content-Disposition value format. - v := p.Header.Get("Content-Disposition") - if v == "" { - return "" + if p.dispositionParams == nil { + p.parseContentDisposition() } - d, params := mime.ParseMediaType(v) - if d != "form-data" { + if p.disposition != "form-data" { return "" } - return params["name"] + return p.dispositionParams["name"] +} + + +// FileName returns the filename parameter of the Part's +// Content-Disposition header. +func (p *Part) FileName() string { + if p.dispositionParams == nil { + p.parseContentDisposition() + } + return p.dispositionParams["filename"] +} + +func (p *Part) parseContentDisposition() { + v := p.Header.Get("Content-Disposition") + p.disposition, p.dispositionParams = mime.ParseMediaType(v) + if p.dispositionParams == nil { + p.dispositionParams = emptyParams + } } // NewReader creates a new multipart Reader reading from r using the // given MIME boundary. func NewReader(reader io.Reader, boundary string) Reader { + b := []byte("\r\n--" + boundary + "--") return &multiReader{ - boundary: boundary, - dashBoundary: "--" + boundary, - endLine: "--" + boundary + "--", - bufReader: bufio.NewReader(reader), + bufReader: bufio.NewReader(reader), + + nlDashBoundary: b[:len(b)-2], + dashBoundaryDash: b[2:], + dashBoundary: b[2 : len(b)-2], } } // Implementation .... -type devNullWriter bool - -func (*devNullWriter) Write(p []byte) (n int, err os.Error) { - return len(p), nil -} - -var devNull = devNullWriter(false) - -func newPart(mr *multiReader) (bp *Part, err os.Error) { - bp = new(Part) - bp.Header = make(map[string][]string) - bp.mr = mr - bp.buffer = new(bytes.Buffer) - if err = bp.populateHeaders(); err != nil { - bp = nil +func newPart(mr *multiReader) (*Part, os.Error) { + bp := &Part{ + Header: make(map[string][]string), + mr: mr, + buffer: new(bytes.Buffer), } - return + if err := bp.populateHeaders(); err != nil { + return nil, err + } + return bp, nil } func (bp *Part) populateHeaders() os.Error { for { - line, err := bp.mr.bufReader.ReadString('\n') + lineBytes, err := bp.mr.bufReader.ReadSlice('\n') if err != nil { return err } + line := string(lineBytes) if line == "\n" || line == "\r\n" { return nil } @@ -116,91 +138,63 @@ func (bp *Part) populateHeaders() os.Error { // Read reads the body of a part, after its headers and before the // next part (if any) begins. func (bp *Part) Read(p []byte) (n int, err os.Error) { - for { - if bp.buffer.Len() >= len(p) { - // Internal buffer of unconsumed data is large enough for - // the read request. No need to parse more at the moment. - break - } - if !bp.mr.ensureBufferedLine() { - return 0, io.ErrUnexpectedEOF - } - if bp.mr.bufferedLineIsBoundary() { - // Don't consume this line - break - } + if bp.buffer.Len() >= len(p) { + // Internal buffer of unconsumed data is large enough for + // the read request. No need to parse more at the moment. + return bp.buffer.Read(p) + } + peek, err := bp.mr.bufReader.Peek(4096) // TODO(bradfitz): add buffer size accessor + unexpectedEof := err == os.EOF + if err != nil && !unexpectedEof { + return 0, fmt.Errorf("multipart: Part Read: %v", err) + } + if peek == nil { + panic("nil peek buf") + } - // Write all of this line, except the final CRLF - s := *bp.mr.bufferedLine - if strings.HasSuffix(s, "\r\n") { - bp.mr.consumeLine() - if !bp.mr.ensureBufferedLine() { - return 0, io.ErrUnexpectedEOF - } - if bp.mr.bufferedLineIsBoundary() { - // The final \r\n isn't ours. It logically belongs - // to the boundary line which follows. - bp.buffer.WriteString(s[0 : len(s)-2]) - } else { - bp.buffer.WriteString(s) - } - break - } - if strings.HasSuffix(s, "\n") { - bp.buffer.WriteString(s) - bp.mr.consumeLine() - continue + // Search the peek buffer for "\r\n--boundary". If found, + // consume everything up to the boundary. If not, consume only + // as much of the peek buffer as cannot hold the boundary + // string. + nCopy := 0 + foundBoundary := false + if idx := bytes.Index(peek, bp.mr.nlDashBoundary); idx != -1 { + nCopy = idx + foundBoundary = true + } else if safeCount := len(peek) - len(bp.mr.nlDashBoundary); safeCount > 0 { + nCopy = safeCount + } else if unexpectedEof { + // If we've run out of peek buffer and the boundary + // wasn't found (and can't possibly fit), we must have + // hit the end of the file unexpectedly. + return 0, io.ErrUnexpectedEOF + } + if nCopy > 0 { + if _, err := io.Copyn(bp.buffer, bp.mr.bufReader, int64(nCopy)); err != nil { + return 0, err } - return 0, os.NewError("multipart parse error during Read; unexpected line: " + s) } - return bp.buffer.Read(p) + n, err = bp.buffer.Read(p) + if err == os.EOF && !foundBoundary { + // If the boundary hasn't been reached there's more to + // read, so don't pass through an EOF from the buffer + err = nil + } + return } func (bp *Part) Close() os.Error { - io.Copy(&devNull, bp) + io.Copy(ioutil.Discard, bp) return nil } type multiReader struct { - boundary string - dashBoundary string // --boundary - endLine string // --boundary-- - - bufferedLine *string + bufReader *bufio.Reader - bufReader *bufio.Reader currentPart *Part partsRead int -} -func (mr *multiReader) eof() bool { - return mr.bufferedLine == nil && - !mr.readLine() -} - -func (mr *multiReader) readLine() bool { - line, err := mr.bufReader.ReadString('\n') - if err != nil { - // TODO: care about err being EOF or not? - return false - } - mr.bufferedLine = &line - return true -} - -func (mr *multiReader) bufferedLineIsBoundary() bool { - return strings.HasPrefix(*mr.bufferedLine, mr.dashBoundary) -} - -func (mr *multiReader) ensureBufferedLine() bool { - if mr.bufferedLine == nil { - return mr.readLine() - } - return true -} - -func (mr *multiReader) consumeLine() { - mr.bufferedLine = nil + nlDashBoundary, dashBoundaryDash, dashBoundary []byte } func (mr *multiReader) NextPart() (*Part, os.Error) { @@ -208,13 +202,14 @@ func (mr *multiReader) NextPart() (*Part, os.Error) { mr.currentPart.Close() } + expectNewPart := false for { - if mr.eof() { - return nil, io.ErrUnexpectedEOF + line, err := mr.bufReader.ReadSlice('\n') + if err != nil { + return nil, fmt.Errorf("multipart: NextPart: %v", err) } - if isBoundaryDelimiterLine(*mr.bufferedLine, mr.dashBoundary) { - mr.consumeLine() + if mr.isBoundaryDelimiterLine(line) { mr.partsRead++ bp, err := newPart(mr) if err != nil { @@ -224,55 +219,66 @@ func (mr *multiReader) NextPart() (*Part, os.Error) { return bp, nil } - if hasPrefixThenNewline(*mr.bufferedLine, mr.endLine) { - mr.consumeLine() - // Expected EOF (no error) - return nil, nil + if hasPrefixThenNewline(line, mr.dashBoundaryDash) { + // Expected EOF + return nil, os.EOF + } + + if expectNewPart { + return nil, fmt.Errorf("multipart: expecting a new Part; got line %q", string(line)) } if mr.partsRead == 0 { // skip line - mr.consumeLine() continue } - return nil, os.NewError("Unexpected line in Next().") + if bytes.Equal(line, []byte("\r\n")) { + // Consume the "\r\n" separator between the + // body of the previous part and the boundary + // line we now expect will follow. (either a + // new part or the end boundary) + expectNewPart = true + continue + } + + return nil, fmt.Errorf("multipart: unexpected line in Next(): %q", line) } panic("unreachable") } -func isBoundaryDelimiterLine(line, dashPrefix string) bool { +func (mr *multiReader) isBoundaryDelimiterLine(line []byte) bool { // http://tools.ietf.org/html/rfc2046#section-5.1 // The boundary delimiter line is then defined as a line // consisting entirely of two hyphen characters ("-", // decimal value 45) followed by the boundary parameter // value from the Content-Type header field, optional linear // whitespace, and a terminating CRLF. - if !strings.HasPrefix(line, dashPrefix) { + if !bytes.HasPrefix(line, mr.dashBoundary) { return false } - if strings.HasSuffix(line, "\r\n") { - return onlyHorizontalWhitespace(line[len(dashPrefix) : len(line)-2]) + if bytes.HasSuffix(line, []byte("\r\n")) { + return onlyHorizontalWhitespace(line[len(mr.dashBoundary) : len(line)-2]) } // Violate the spec and also support newlines without the // carriage return... - if strings.HasSuffix(line, "\n") { - return onlyHorizontalWhitespace(line[len(dashPrefix) : len(line)-1]) + if bytes.HasSuffix(line, []byte("\n")) { + return onlyHorizontalWhitespace(line[len(mr.dashBoundary) : len(line)-1]) } return false } -func onlyHorizontalWhitespace(s string) bool { - for i := 0; i < len(s); i++ { - if s[i] != ' ' && s[i] != '\t' { +func onlyHorizontalWhitespace(s []byte) bool { + for _, b := range s { + if b != ' ' && b != '\t' { return false } } return true } -func hasPrefixThenNewline(s, prefix string) bool { - return strings.HasPrefix(s, prefix) && - (len(s) == len(prefix)+1 && strings.HasSuffix(s, "\n") || - len(s) == len(prefix)+2 && strings.HasSuffix(s, "\r\n")) +func hasPrefixThenNewline(s, prefix []byte) bool { + return bytes.HasPrefix(s, prefix) && + (len(s) == len(prefix)+1 && s[len(s)-1] == '\n' || + len(s) == len(prefix)+2 && bytes.HasSuffix(s, []byte("\r\n"))) } diff --git a/libgo/go/mime/multipart/multipart_test.go b/libgo/go/mime/multipart/multipart_test.go index 1f3d32d7ed6..8222fbd8a4d 100644 --- a/libgo/go/mime/multipart/multipart_test.go +++ b/libgo/go/mime/multipart/multipart_test.go @@ -8,37 +8,37 @@ import ( "bytes" "fmt" "io" + "io/ioutil" "json" - "regexp" + "os" "strings" "testing" ) func TestHorizontalWhitespace(t *testing.T) { - if !onlyHorizontalWhitespace(" \t") { + if !onlyHorizontalWhitespace([]byte(" \t")) { t.Error("expected pass") } - if onlyHorizontalWhitespace("foo bar") { + if onlyHorizontalWhitespace([]byte("foo bar")) { t.Error("expected failure") } } func TestBoundaryLine(t *testing.T) { - boundary := "myBoundary" - prefix := "--" + boundary - if !isBoundaryDelimiterLine("--myBoundary\r\n", prefix) { + mr := NewReader(strings.NewReader(""), "myBoundary").(*multiReader) + if !mr.isBoundaryDelimiterLine([]byte("--myBoundary\r\n")) { t.Error("expected") } - if !isBoundaryDelimiterLine("--myBoundary \r\n", prefix) { + if !mr.isBoundaryDelimiterLine([]byte("--myBoundary \r\n")) { t.Error("expected") } - if !isBoundaryDelimiterLine("--myBoundary \n", prefix) { + if !mr.isBoundaryDelimiterLine([]byte("--myBoundary \n")) { t.Error("expected") } - if isBoundaryDelimiterLine("--myBoundary bogus \n", prefix) { + if mr.isBoundaryDelimiterLine([]byte("--myBoundary bogus \n")) { t.Error("expected fail") } - if isBoundaryDelimiterLine("--myBoundary bogus--", prefix) { + if mr.isBoundaryDelimiterLine([]byte("--myBoundary bogus--")) { t.Error("expected fail") } } @@ -56,29 +56,32 @@ func expectEq(t *testing.T, expected, actual, what string) { what, escapeString(actual), len(actual), escapeString(expected), len(expected)) } -func TestFormName(t *testing.T) { - p := new(Part) - p.Header = make(map[string][]string) - tests := [...][2]string{ - {`form-data; name="foo"`, "foo"}, - {` form-data ; name=foo`, "foo"}, - {`FORM-DATA;name="foo"`, "foo"}, - {` FORM-DATA ; name="foo"`, "foo"}, - {` FORM-DATA ; name="foo"`, "foo"}, - {` FORM-DATA ; name=foo`, "foo"}, - {` FORM-DATA ; filename="foo.txt"; name=foo; baz=quux`, "foo"}, +func TestNameAccessors(t *testing.T) { + tests := [...][3]string{ + {`form-data; name="foo"`, "foo", ""}, + {` form-data ; name=foo`, "foo", ""}, + {`FORM-DATA;name="foo"`, "foo", ""}, + {` FORM-DATA ; name="foo"`, "foo", ""}, + {` FORM-DATA ; name="foo"`, "foo", ""}, + {` FORM-DATA ; name=foo`, "foo", ""}, + {` FORM-DATA ; filename="foo.txt"; name=foo; baz=quux`, "foo", "foo.txt"}, + {` not-form-data ; filename="bar.txt"; name=foo; baz=quux`, "", "bar.txt"}, } - for _, test := range tests { + for i, test := range tests { + p := &Part{Header: make(map[string][]string)} p.Header.Set("Content-Disposition", test[0]) - expected := test[1] - actual := p.FormName() - if actual != expected { - t.Errorf("expected \"%s\"; got: \"%s\"", expected, actual) + if g, e := p.FormName(), test[1]; g != e { + t.Errorf("test %d: FormName() = %q; want %q", i, g, e) + } + if g, e := p.FileName(), test[2]; g != e { + t.Errorf("test %d: FileName() = %q; want %q", i, g, e) } } } -func TestMultipart(t *testing.T) { +var longLine = strings.Repeat("\n\n\r\r\r\n\r\000", (1<<20)/8) + +func testMultipartBody() string { testBody := ` This is a multi-part message. This line is ignored. --MyBoundary @@ -89,6 +92,10 @@ foo-bar: baz My value The end. --MyBoundary +name: bigsection + +[longline] +--MyBoundary Header1: value1b HEADER2: value2b foo-bar: bazb @@ -101,11 +108,26 @@ Line 3 ends in a newline, but just one. never read data --MyBoundary-- + + +useless trailer ` - testBody = regexp.MustCompile("\n").ReplaceAllString(testBody, "\r\n") - bodyReader := strings.NewReader(testBody) + testBody = strings.Replace(testBody, "\n", "\r\n", -1) + return strings.Replace(testBody, "[longline]", longLine, 1) +} - reader := NewReader(bodyReader, "MyBoundary") +func TestMultipart(t *testing.T) { + bodyReader := strings.NewReader(testMultipartBody()) + testMultipart(t, bodyReader) +} + +func TestMultipartSlowInput(t *testing.T) { + bodyReader := strings.NewReader(testMultipartBody()) + testMultipart(t, &slowReader{bodyReader}) +} + +func testMultipart(t *testing.T, r io.Reader) { + reader := NewReader(r, "MyBoundary") buf := new(bytes.Buffer) // Part1 @@ -124,38 +146,64 @@ never read data t.Error("Expected Foo-Bar: baz") } buf.Reset() - io.Copy(buf, part) + if _, err := io.Copy(buf, part); err != nil { + t.Errorf("part 1 copy: %v", err) + } expectEq(t, "My value\r\nThe end.", buf.String(), "Value of first part") // Part2 part, err = reader.NextPart() + if err != nil { + t.Fatalf("Expected part2; got: %v", err) + return + } + if e, g := "bigsection", part.Header.Get("name"); e != g { + t.Errorf("part2's name header: expected %q, got %q", e, g) + } + buf.Reset() + if _, err := io.Copy(buf, part); err != nil { + t.Errorf("part 2 copy: %v", err) + } + s := buf.String() + if len(s) != len(longLine) { + t.Errorf("part2 body expected long line of length %d; got length %d", + len(longLine), len(s)) + } + if s != longLine { + t.Errorf("part2 long body didn't match") + } + + // Part3 + part, err = reader.NextPart() if part == nil || err != nil { - t.Error("Expected part2") + t.Error("Expected part3") return } if part.Header.Get("foo-bar") != "bazb" { t.Error("Expected foo-bar: bazb") } buf.Reset() - io.Copy(buf, part) + if _, err := io.Copy(buf, part); err != nil { + t.Errorf("part 3 copy: %v", err) + } expectEq(t, "Line 1\r\nLine 2\r\nLine 3 ends in a newline, but just one.\r\n", - buf.String(), "Value of second part") + buf.String(), "body of part 3") - // Part3 + // Part4 part, err = reader.NextPart() if part == nil || err != nil { - t.Error("Expected part3 without errors") + t.Error("Expected part 4 without errors") return } - // Non-existent part4 + // Non-existent part5 part, err = reader.NextPart() if part != nil { - t.Error("Didn't expect a third part.") + t.Error("Didn't expect a fifth part.") } - if err != nil { - t.Errorf("Unexpected error getting third part: %v", err) + if err != os.EOF { + t.Errorf("On fifth part expected os.EOF; got %v", err) } } @@ -199,9 +247,73 @@ func TestVariousTextLineEndings(t *testing.T) { if part != nil { t.Errorf("Unexpected part in test %d", testNum) } - if err != nil { - t.Errorf("Unexpected error in test %d: %v", testNum, err) + if err != os.EOF { + t.Errorf("On test %d expected os.EOF; got %v", testNum, err) } } } + +type maliciousReader struct { + t *testing.T + n int +} + +const maxReadThreshold = 1 << 20 + +func (mr *maliciousReader) Read(b []byte) (n int, err os.Error) { + mr.n += len(b) + if mr.n >= maxReadThreshold { + mr.t.Fatal("too much was read") + return 0, os.EOF + } + return len(b), nil +} + +func TestLineLimit(t *testing.T) { + mr := &maliciousReader{t: t} + r := NewReader(mr, "fooBoundary") + part, err := r.NextPart() + if part != nil { + t.Errorf("unexpected part read") + } + if err == nil { + t.Errorf("expected an error") + } + if mr.n >= maxReadThreshold { + t.Errorf("expected to read < %d bytes; read %d", maxReadThreshold, mr.n) + } +} + +func TestMultipartTruncated(t *testing.T) { + testBody := ` +This is a multi-part message. This line is ignored. +--MyBoundary +foo-bar: baz + +Oh no, premature EOF! +` + body := strings.Replace(testBody, "\n", "\r\n", -1) + bodyReader := strings.NewReader(body) + r := NewReader(bodyReader, "MyBoundary") + + part, err := r.NextPart() + if err != nil { + t.Fatalf("didn't get a part") + } + _, err = io.Copy(ioutil.Discard, part) + if err != io.ErrUnexpectedEOF { + t.Fatalf("expected error io.ErrUnexpectedEOF; got %v", err) + } +} + +type slowReader struct { + r io.Reader +} + +func (s *slowReader) Read(p []byte) (int, os.Error) { + if len(p) == 0 { + return s.r.Read(p) + } + return s.r.Read(p[:1]) +} diff --git a/libgo/go/mime/type.go b/libgo/go/mime/type.go index 6fe0ed5fd5e..8c43b81b0c5 100644 --- a/libgo/go/mime/type.go +++ b/libgo/go/mime/type.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The mime package implements parts of the MIME spec. +// Package mime implements parts of the MIME spec. package mime import ( diff --git a/libgo/go/net/cgo_bsd.go b/libgo/go/net/cgo_bsd.go new file mode 100644 index 00000000000..d9fef45de0a --- /dev/null +++ b/libgo/go/net/cgo_bsd.go @@ -0,0 +1,15 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +/* +#include +*/ + +import "syscall" + +func cgoAddrInfoMask() C.int { + return syscall.AI_MASK +} diff --git a/libgo/go/net/cgo_linux.go b/libgo/go/net/cgo_linux.go new file mode 100644 index 00000000000..482435221e0 --- /dev/null +++ b/libgo/go/net/cgo_linux.go @@ -0,0 +1,15 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +/* +#include +*/ + +import "syscall" + +func cgoAddrInfoMask() int { + return syscall.AI_CANONNAME | syscall.AI_V4MAPPED | syscall.AI_ALL +} diff --git a/libgo/go/net/cgo_stub.go b/libgo/go/net/cgo_stub.go index e28f6622e93..c6277cb657c 100644 --- a/libgo/go/net/cgo_stub.go +++ b/libgo/go/net/cgo_stub.go @@ -19,3 +19,7 @@ func cgoLookupPort(network, service string) (port int, err os.Error, completed b func cgoLookupIP(name string) (addrs []IP, err os.Error, completed bool) { return nil, nil, false } + +func cgoLookupCNAME(name string) (cname string, err os.Error, completed bool) { + return "", nil, false +} diff --git a/libgo/go/net/cgo_unix.go b/libgo/go/net/cgo_unix.go new file mode 100644 index 00000000000..b8090181293 --- /dev/null +++ b/libgo/go/net/cgo_unix.go @@ -0,0 +1,149 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +/* +#include +#include +#include +#include +#include +#include +#include +*/ + +import ( + "os" + "syscall" + "unsafe" +) + +func libc_getaddrinfo(node *byte, service *byte, hints *syscall.Addrinfo, res **syscall.Addrinfo) int __asm__ ("getaddrinfo") +func libc_freeaddrinfo(res *syscall.Addrinfo) __asm__ ("freeaddrinfo") +func libc_gai_strerror(errcode int) *byte __asm__ ("gai_strerror") + +func cgoLookupHost(name string) (addrs []string, err os.Error, completed bool) { + ip, err, completed := cgoLookupIP(name) + for _, p := range ip { + addrs = append(addrs, p.String()) + } + return +} + +func cgoLookupPort(net, service string) (port int, err os.Error, completed bool) { + var res *syscall.Addrinfo + var hints syscall.Addrinfo + + switch net { + case "": + // no hints + case "tcp", "tcp4", "tcp6": + hints.Ai_socktype = syscall.SOCK_STREAM + hints.Ai_protocol = syscall.IPPROTO_TCP + case "udp", "udp4", "udp6": + hints.Ai_socktype = syscall.SOCK_DGRAM + hints.Ai_protocol = syscall.IPPROTO_UDP + default: + return 0, UnknownNetworkError(net), true + } + if len(net) >= 4 { + switch net[3] { + case '4': + hints.Ai_family = syscall.AF_INET + case '6': + hints.Ai_family = syscall.AF_INET6 + } + } + + s := syscall.StringBytePtr(service) + if libc_getaddrinfo(nil, s, &hints, &res) == 0 { + defer libc_freeaddrinfo(res) + for r := res; r != nil; r = r.Ai_next { + switch r.Ai_family { + default: + continue + case syscall.AF_INET: + sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(r.Ai_addr)) + p := (*[2]byte)(unsafe.Pointer(&sa.Port)) + return int(p[0])<<8 | int(p[1]), nil, true + case syscall.AF_INET6: + sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(r.Ai_addr)) + p := (*[2]byte)(unsafe.Pointer(&sa.Port)) + return int(p[0])<<8 | int(p[1]), nil, true + } + } + } + return 0, &AddrError{"unknown port", net + "/" + service}, true +} + +func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err os.Error, completed bool) { + var res *syscall.Addrinfo + var hints syscall.Addrinfo + + // NOTE(rsc): In theory there are approximately balanced + // arguments for and against including AI_ADDRCONFIG + // in the flags (it includes IPv4 results only on IPv4 systems, + // and similarly for IPv6), but in practice setting it causes + // getaddrinfo to return the wrong canonical name on Linux. + // So definitely leave it out. + hints.Ai_flags = int32((syscall.AI_ALL | syscall.AI_V4MAPPED | syscall.AI_CANONNAME) & cgoAddrInfoMask()) + + h := syscall.StringBytePtr(name) + gerrno := libc_getaddrinfo(h, nil, &hints, &res) + if gerrno != 0 { + var str string + if gerrno == syscall.EAI_NONAME { + str = noSuchHost + } else if gerrno == syscall.EAI_SYSTEM { + str = syscall.Errstr(syscall.GetErrno()) + } else { + str = syscall.BytePtrToString(libc_gai_strerror(gerrno)) + } + return nil, "", &DNSError{Error: str, Name: name}, true + } + defer libc_freeaddrinfo(res) + if res != nil { + cname = syscall.BytePtrToString((*byte)(unsafe.Pointer(res.Ai_canonname))) + if cname == "" { + cname = name + } + if len(cname) > 0 && cname[len(cname)-1] != '.' { + cname += "." + } + } + for r := res; r != nil; r = r.Ai_next { + // Everything comes back twice, once for UDP and once for TCP. + if r.Ai_socktype != syscall.SOCK_STREAM { + continue + } + switch r.Ai_family { + default: + continue + case syscall.AF_INET: + sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(r.Ai_addr)) + addrs = append(addrs, copyIP(sa.Addr[:])) + case syscall.AF_INET6: + sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(r.Ai_addr)) + addrs = append(addrs, copyIP(sa.Addr[:])) + } + } + return addrs, cname, nil, true +} + +func cgoLookupIP(name string) (addrs []IP, err os.Error, completed bool) { + addrs, _, err, completed = cgoLookupIPCNAME(name) + return +} + +func cgoLookupCNAME(name string) (cname string, err os.Error, completed bool) { + _, cname, err, completed = cgoLookupIPCNAME(name) + return +} + +func copyIP(x IP) IP { + y := make(IP, len(x)) + copy(y, x) + return y +} diff --git a/libgo/go/net/dial.go b/libgo/go/net/dial.go index 66cb09b19bb..16896b4269b 100644 --- a/libgo/go/net/dial.go +++ b/libgo/go/net/dial.go @@ -30,7 +30,7 @@ func Dial(net, addr string) (c Conn, err os.Error) { switch net { case "tcp", "tcp4", "tcp6": var ra *TCPAddr - if ra, err = ResolveTCPAddr(raddr); err != nil { + if ra, err = ResolveTCPAddr(net, raddr); err != nil { goto Error } c, err := DialTCP(net, nil, ra) @@ -40,7 +40,7 @@ func Dial(net, addr string) (c Conn, err os.Error) { return c, nil case "udp", "udp4", "udp6": var ra *UDPAddr - if ra, err = ResolveUDPAddr(raddr); err != nil { + if ra, err = ResolveUDPAddr(net, raddr); err != nil { goto Error } c, err := DialUDP(net, nil, ra) @@ -83,7 +83,7 @@ func Listen(net, laddr string) (l Listener, err os.Error) { case "tcp", "tcp4", "tcp6": var la *TCPAddr if laddr != "" { - if la, err = ResolveTCPAddr(laddr); err != nil { + if la, err = ResolveTCPAddr(net, laddr); err != nil { return nil, err } } @@ -116,7 +116,7 @@ func ListenPacket(net, laddr string) (c PacketConn, err os.Error) { case "udp", "udp4", "udp6": var la *UDPAddr if laddr != "" { - if la, err = ResolveUDPAddr(laddr); err != nil { + if la, err = ResolveUDPAddr(net, laddr); err != nil { return nil, err } } diff --git a/libgo/go/net/dialgoogle_test.go b/libgo/go/net/dialgoogle_test.go index 9a9c02ebd71..e90c4f3f894 100644 --- a/libgo/go/net/dialgoogle_test.go +++ b/libgo/go/net/dialgoogle_test.go @@ -41,7 +41,19 @@ func doDial(t *testing.T, network, addr string) { fd.Close() } -var googleaddrs = []string{ +func TestLookupCNAME(t *testing.T) { + if testing.Short() { + // Don't use external network. + t.Logf("skipping external network test during -short") + return + } + cname, err := LookupCNAME("www.google.com") + if !strings.HasSuffix(cname, ".l.google.com.") || err != nil { + t.Errorf(`LookupCNAME("www.google.com.") = %q, %v, want "*.l.google.com.", nil`, cname, err) + } +} + +var googleaddrsipv4 = []string{ "%d.%d.%d.%d:80", "www.google.com:80", "%d.%d.%d.%d:http", @@ -52,42 +64,40 @@ var googleaddrs = []string{ "[0:0:0:0:0000:ffff:%d.%d.%d.%d]:80", "[0:0:0:0:000000:ffff:%d.%d.%d.%d]:80", "[0:0:0:0:0:ffff::%d.%d.%d.%d]:80", - "[2001:4860:0:2001::68]:80", // ipv6.google.com; removed if ipv6 flag not set } -func TestLookupCNAME(t *testing.T) { - cname, err := LookupCNAME("www.google.com") - if cname != "www.l.google.com." || err != nil { - t.Errorf(`LookupCNAME("www.google.com.") = %q, %v, want "www.l.google.com.", nil`, cname, err) - } -} - -func TestDialGoogle(t *testing.T) { - // If no ipv6 tunnel, don't try the last address. - if !*ipv6 { - googleaddrs[len(googleaddrs)-1] = "" +func TestDialGoogleIPv4(t *testing.T) { + if testing.Short() { + // Don't use external network. + t.Logf("skipping external network test during -short") + return } - // Insert an actual IP address for google.com + // Insert an actual IPv4 address for google.com // into the table. - addrs, err := LookupIP("www.google.com") if err != nil { t.Fatalf("lookup www.google.com: %v", err) } - if len(addrs) == 0 { - t.Fatalf("no addresses for www.google.com") + var ip IP + for _, addr := range addrs { + if x := addr.To4(); x != nil { + ip = x + break + } + } + if ip == nil { + t.Fatalf("no IPv4 addresses for www.google.com") } - ip := addrs[0].To4() - for i, s := range googleaddrs { + for i, s := range googleaddrsipv4 { if strings.Contains(s, "%") { - googleaddrs[i] = fmt.Sprintf(s, ip[0], ip[1], ip[2], ip[3]) + googleaddrsipv4[i] = fmt.Sprintf(s, ip[0], ip[1], ip[2], ip[3]) } } - for i := 0; i < len(googleaddrs); i++ { - addr := googleaddrs[i] + for i := 0; i < len(googleaddrsipv4); i++ { + addr := googleaddrsipv4[i] if addr == "" { continue } @@ -95,20 +105,67 @@ func TestDialGoogle(t *testing.T) { doDial(t, "tcp", addr) if addr[0] != '[' { doDial(t, "tcp4", addr) - if !preferIPv4 { // make sure preferIPv4 flag works. preferIPv4 = true syscall.SocketDisableIPv6 = true + doDial(t, "tcp", addr) doDial(t, "tcp4", addr) syscall.SocketDisableIPv6 = false preferIPv4 = false } } + } +} + +var googleaddrsipv6 = []string{ + "[%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x]:80", + "ipv6.google.com:80", + "[%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x]:http", + "ipv6.google.com:http", +} - // Only run tcp6 if the kernel will take it. - if kernelSupportsIPv6() { - doDial(t, "tcp6", addr) +func TestDialGoogleIPv6(t *testing.T) { + if testing.Short() { + // Don't use external network. + t.Logf("skipping external network test during -short") + return + } + // Only run tcp6 if the kernel will take it. + if !*ipv6 || !kernelSupportsIPv6() { + return + } + + // Insert an actual IPv6 address for ipv6.google.com + // into the table. + addrs, err := LookupIP("ipv6.google.com") + if err != nil { + t.Fatalf("lookup ipv6.google.com: %v", err) + } + var ip IP + for _, addr := range addrs { + if x := addr.To16(); x != nil { + ip = x + break + } + } + if ip == nil { + t.Fatalf("no IPv6 addresses for ipv6.google.com") + } + + for i, s := range googleaddrsipv6 { + if strings.Contains(s, "%") { + googleaddrsipv6[i] = fmt.Sprintf(s, ip[0], ip[1], ip[2], ip[3], ip[4], ip[5], ip[6], ip[7], ip[8], ip[9], ip[10], ip[11], ip[12], ip[13], ip[14], ip[15]) + } + } + + for i := 0; i < len(googleaddrsipv6); i++ { + addr := googleaddrsipv6[i] + if addr == "" { + continue } + t.Logf("-- %s --", addr) + doDial(t, "tcp", addr) + doDial(t, "tcp6", addr) } } diff --git a/libgo/go/net/dnsclient.go b/libgo/go/net/dnsclient.go index 32cea6125eb..3466003fab8 100644 --- a/libgo/go/net/dnsclient.go +++ b/libgo/go/net/dnsclient.go @@ -21,6 +21,7 @@ import ( "rand" "sync" "time" + "sort" ) // DNSError represents a DNS lookup error. @@ -120,15 +121,19 @@ func answer(name, server string, dns *dnsMsg, qtype uint16) (cname string, addrs Cname: for cnameloop := 0; cnameloop < 10; cnameloop++ { addrs = addrs[0:0] - for i := 0; i < len(dns.answer); i++ { - rr := dns.answer[i] + for _, rr := range dns.answer { + if _, justHeader := rr.(*dnsRR_Header); justHeader { + // Corrupt record: we only have a + // header. That header might say it's + // of type qtype, but we don't + // actually have it. Skip. + continue + } h := rr.Header() if h.Class == dnsClassINET && h.Name == name { switch h.Rrtype { case qtype: - n := len(addrs) - addrs = addrs[0 : n+1] - addrs[n] = rr + addrs = append(addrs, rr) case dnsTypeCNAME: // redirect to cname name = rr.(*dnsRR_CNAME).Cname @@ -180,8 +185,7 @@ func tryOneName(cfg *dnsConfig, name string, qtype uint16) (cname string, addrs func convertRR_A(records []dnsRR) []IP { addrs := make([]IP, len(records)) - for i := 0; i < len(records); i++ { - rr := records[i] + for i, rr := range records { a := rr.(*dnsRR_A).A addrs[i] = IPv4(byte(a>>24), byte(a>>16), byte(a>>8), byte(a)) } @@ -190,8 +194,7 @@ func convertRR_A(records []dnsRR) []IP { func convertRR_AAAA(records []dnsRR) []IP { addrs := make([]IP, len(records)) - for i := 0; i < len(records); i++ { - rr := records[i] + for i, rr := range records { a := make(IP, 16) copy(a, rr.(*dnsRR_AAAA).AAAA[:]) addrs[i] = a @@ -306,17 +309,22 @@ func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err os.Erro } // goLookupHost is the native Go implementation of LookupHost. +// Used only if cgoLookupHost refuses to handle the request +// (that is, only if cgoLookupHost is the stub in cgo_stub.go). +// Normally we let cgo use the C library resolver instead of +// depending on our lookup code, so that Go and C get the same +// answers. func goLookupHost(name string) (addrs []string, err os.Error) { - onceLoadConfig.Do(loadConfig) - if dnserr != nil || cfg == nil { - err = dnserr - return - } // Use entries from /etc/hosts if they match. addrs = lookupStaticHost(name) if len(addrs) > 0 { return } + onceLoadConfig.Do(loadConfig) + if dnserr != nil || cfg == nil { + err = dnserr + return + } ips, err := goLookupIP(name) if err != nil { return @@ -329,6 +337,11 @@ func goLookupHost(name string) (addrs []string, err os.Error) { } // goLookupIP is the native Go implementation of LookupIP. +// Used only if cgoLookupIP refuses to handle the request +// (that is, only if cgoLookupIP is the stub in cgo_stub.go). +// Normally we let cgo use the C library resolver instead of +// depending on our lookup code, so that Go and C get the same +// answers. func goLookupIP(name string) (addrs []IP, err os.Error) { onceLoadConfig.Do(loadConfig) if dnserr != nil || cfg == nil { @@ -357,11 +370,13 @@ func goLookupIP(name string) (addrs []IP, err os.Error) { return } -// LookupCNAME returns the canonical DNS host for the given name. -// Callers that do not care about the canonical name can call -// LookupHost or LookupIP directly; both take care of resolving -// the canonical name as part of the lookup. -func LookupCNAME(name string) (cname string, err os.Error) { +// goLookupCNAME is the native Go implementation of LookupCNAME. +// Used only if cgoLookupCNAME refuses to handle the request +// (that is, only if cgoLookupCNAME is the stub in cgo_stub.go). +// Normally we let cgo use the C library resolver instead of +// depending on our lookup code, so that Go and C get the same +// answers. +func goLookupCNAME(name string) (cname string, err os.Error) { onceLoadConfig.Do(loadConfig) if dnserr != nil || cfg == nil { err = dnserr @@ -371,9 +386,7 @@ func LookupCNAME(name string) (cname string, err os.Error) { if err != nil { return } - if len(rr) >= 0 { - cname = rr[0].(*dnsRR_CNAME).Cname - } + cname = rr[0].(*dnsRR_CNAME).Cname return } @@ -397,8 +410,8 @@ func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os. return } addrs = make([]*SRV, len(records)) - for i := 0; i < len(records); i++ { - r := records[i].(*dnsRR_SRV) + for i, rr := range records { + r := rr.(*dnsRR_SRV) addrs[i] = &SRV{r.Target, r.Port, r.Priority, r.Weight} } return @@ -410,18 +423,32 @@ type MX struct { Pref uint16 } -// LookupMX returns the DNS MX records associated with name. -func LookupMX(name string) (entries []*MX, err os.Error) { - var records []dnsRR - _, records, err = lookup(name, dnsTypeMX) +// byPref implements sort.Interface to sort MX records by preference +type byPref []*MX + +func (s byPref) Len() int { return len(s) } + +func (s byPref) Less(i, j int) bool { return s[i].Pref < s[j].Pref } + +func (s byPref) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// LookupMX returns the DNS MX records for the given domain name sorted by preference. +func LookupMX(name string) (mx []*MX, err os.Error) { + _, rr, err := lookup(name, dnsTypeMX) if err != nil { return } - entries = make([]*MX, len(records)) - for i := range records { - r := records[i].(*dnsRR_MX) - entries[i] = &MX{r.Mx, r.Pref} + mx = make([]*MX, len(rr)) + for i := range rr { + r := rr[i].(*dnsRR_MX) + mx[i] = &MX{r.Mx, r.Pref} + } + // Shuffle the records to match RFC 5321 when sorted + for i := range mx { + j := rand.Intn(i + 1) + mx[i], mx[j] = mx[j], mx[i] } + sort.Sort(byPref(mx)) return } diff --git a/libgo/go/net/dnsmsg.go b/libgo/go/net/dnsmsg.go index 5209c1a06a5..731efe26a44 100644 --- a/libgo/go/net/dnsmsg.go +++ b/libgo/go/net/dnsmsg.go @@ -390,52 +390,48 @@ Loop: // TODO(rsc): Move into generic library? // Pack a reflect.StructValue into msg. Struct members can only be uint16, uint32, string, // [n]byte, and other (often anonymous) structs. -func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) { +func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) { for i := 0; i < val.NumField(); i++ { - f := val.Type().(*reflect.StructType).Field(i) - switch fv := val.Field(i).(type) { + f := val.Type().Field(i) + switch fv := val.Field(i); fv.Kind() { default: BadType: fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) return len(msg), false - case *reflect.StructValue: + case reflect.Struct: off, ok = packStructValue(fv, msg, off) - case *reflect.UintValue: - i := fv.Get() - switch fv.Type().Kind() { - default: - goto BadType - case reflect.Uint16: - if off+2 > len(msg) { - return len(msg), false - } - msg[off] = byte(i >> 8) - msg[off+1] = byte(i) - off += 2 - case reflect.Uint32: - if off+4 > len(msg) { - return len(msg), false - } - msg[off] = byte(i >> 24) - msg[off+1] = byte(i >> 16) - msg[off+2] = byte(i >> 8) - msg[off+3] = byte(i) - off += 4 + case reflect.Uint16: + if off+2 > len(msg) { + return len(msg), false } - case *reflect.ArrayValue: - if fv.Type().(*reflect.ArrayType).Elem().Kind() != reflect.Uint8 { + i := fv.Uint() + msg[off] = byte(i >> 8) + msg[off+1] = byte(i) + off += 2 + case reflect.Uint32: + if off+4 > len(msg) { + return len(msg), false + } + i := fv.Uint() + msg[off] = byte(i >> 24) + msg[off+1] = byte(i >> 16) + msg[off+2] = byte(i >> 8) + msg[off+3] = byte(i) + off += 4 + case reflect.Array: + if fv.Type().Elem().Kind() != reflect.Uint8 { goto BadType } n := fv.Len() if off+n > len(msg) { return len(msg), false } - reflect.Copy(reflect.NewValue(msg[off:off+n]).(*reflect.SliceValue), fv) + reflect.Copy(reflect.ValueOf(msg[off:off+n]), fv) off += n - case *reflect.StringValue: + case reflect.String: // There are multiple string encodings. // The tag distinguishes ordinary strings from domain names. - s := fv.Get() + s := fv.String() switch f.Tag { default: fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag) @@ -459,8 +455,8 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o return off, true } -func structValue(any interface{}) *reflect.StructValue { - return reflect.NewValue(any).(*reflect.PtrValue).Elem().(*reflect.StructValue) +func structValue(any interface{}) reflect.Value { + return reflect.ValueOf(any).Elem() } func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { @@ -471,46 +467,41 @@ func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { // TODO(rsc): Move into generic library? // Unpack a reflect.StructValue from msg. // Same restrictions as packStructValue. -func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) { +func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) { for i := 0; i < val.NumField(); i++ { - f := val.Type().(*reflect.StructType).Field(i) - switch fv := val.Field(i).(type) { + f := val.Type().Field(i) + switch fv := val.Field(i); fv.Kind() { default: BadType: fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) return len(msg), false - case *reflect.StructValue: + case reflect.Struct: off, ok = unpackStructValue(fv, msg, off) - case *reflect.UintValue: - switch fv.Type().Kind() { - default: - goto BadType - case reflect.Uint16: - if off+2 > len(msg) { - return len(msg), false - } - i := uint16(msg[off])<<8 | uint16(msg[off+1]) - fv.Set(uint64(i)) - off += 2 - case reflect.Uint32: - if off+4 > len(msg) { - return len(msg), false - } - i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) - fv.Set(uint64(i)) - off += 4 + case reflect.Uint16: + if off+2 > len(msg) { + return len(msg), false } - case *reflect.ArrayValue: - if fv.Type().(*reflect.ArrayType).Elem().Kind() != reflect.Uint8 { + i := uint16(msg[off])<<8 | uint16(msg[off+1]) + fv.SetUint(uint64(i)) + off += 2 + case reflect.Uint32: + if off+4 > len(msg) { + return len(msg), false + } + i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) + fv.SetUint(uint64(i)) + off += 4 + case reflect.Array: + if fv.Type().Elem().Kind() != reflect.Uint8 { goto BadType } n := fv.Len() if off+n > len(msg) { return len(msg), false } - reflect.Copy(fv, reflect.NewValue(msg[off:off+n]).(*reflect.SliceValue)) + reflect.Copy(fv, reflect.ValueOf(msg[off:off+n])) off += n - case *reflect.StringValue: + case reflect.String: var s string switch f.Tag { default: @@ -534,7 +525,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, off += n s = string(b) } - fv.Set(s) + fv.SetString(s) } } return off, true @@ -550,23 +541,23 @@ func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { // but does look for an "ipv4" tag on uint32 variables // and the "ipv6" tag on array variables, // printing them as IP addresses. -func printStructValue(val *reflect.StructValue) string { +func printStructValue(val reflect.Value) string { s := "{" for i := 0; i < val.NumField(); i++ { if i > 0 { s += ", " } - f := val.Type().(*reflect.StructType).Field(i) + f := val.Type().Field(i) if !f.Anonymous { s += f.Name + "=" } fval := val.Field(i) - if fv, ok := fval.(*reflect.StructValue); ok { + if fv := fval; fv.Kind() == reflect.Struct { s += printStructValue(fv) - } else if fv, ok := fval.(*reflect.UintValue); ok && f.Tag == "ipv4" { - i := fv.Get() + } else if fv := fval; (fv.Kind() == reflect.Uint || fv.Kind() == reflect.Uint8 || fv.Kind() == reflect.Uint16 || fv.Kind() == reflect.Uint32 || fv.Kind() == reflect.Uint64 || fv.Kind() == reflect.Uintptr) && f.Tag == "ipv4" { + i := fv.Uint() s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String() - } else if fv, ok := fval.(*reflect.ArrayValue); ok && f.Tag == "ipv6" { + } else if fv := fval; fv.Kind() == reflect.Array && f.Tag == "ipv6" { i := fv.Interface().([]byte) s += IP(i).String() } else { @@ -724,24 +715,35 @@ func (dns *dnsMsg) Unpack(msg []byte) bool { // Arrays. dns.question = make([]dnsQuestion, dh.Qdcount) - dns.answer = make([]dnsRR, dh.Ancount) - dns.ns = make([]dnsRR, dh.Nscount) - dns.extra = make([]dnsRR, dh.Arcount) + dns.answer = make([]dnsRR, 0, dh.Ancount) + dns.ns = make([]dnsRR, 0, dh.Nscount) + dns.extra = make([]dnsRR, 0, dh.Arcount) + + var rec dnsRR for i := 0; i < len(dns.question); i++ { off, ok = unpackStruct(&dns.question[i], msg, off) } - for i := 0; i < len(dns.answer); i++ { - dns.answer[i], off, ok = unpackRR(msg, off) - } - for i := 0; i < len(dns.ns); i++ { - dns.ns[i], off, ok = unpackRR(msg, off) + for i := 0; i < int(dh.Ancount); i++ { + rec, off, ok = unpackRR(msg, off) + if !ok { + return false + } + dns.answer = append(dns.answer, rec) } - for i := 0; i < len(dns.extra); i++ { - dns.extra[i], off, ok = unpackRR(msg, off) + for i := 0; i < int(dh.Nscount); i++ { + rec, off, ok = unpackRR(msg, off) + if !ok { + return false + } + dns.ns = append(dns.ns, rec) } - if !ok { - return false + for i := 0; i < int(dh.Arcount); i++ { + rec, off, ok = unpackRR(msg, off) + if !ok { + return false + } + dns.extra = append(dns.extra, rec) } // if off != len(msg) { // println("extra bytes in dns packet", off, "<", len(msg)); diff --git a/libgo/go/net/dnsmsg_test.go b/libgo/go/net/dnsmsg_test.go new file mode 100644 index 00000000000..20c9f02b0b4 --- /dev/null +++ b/libgo/go/net/dnsmsg_test.go @@ -0,0 +1,107 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "encoding/hex" + "runtime" + "testing" +) + +func TestDNSParseSRVReply(t *testing.T) { + if runtime.GOOS == "windows" { + return + } + data, err := hex.DecodeString(dnsSRVReply) + if err != nil { + t.Fatal(err) + } + msg := new(dnsMsg) + ok := msg.Unpack(data) + if !ok { + t.Fatalf("unpacking packet failed") + } + if g, e := len(msg.answer), 5; g != e { + t.Errorf("len(msg.answer) = %d; want %d", g, e) + } + for idx, rr := range msg.answer { + if g, e := rr.Header().Rrtype, uint16(dnsTypeSRV); g != e { + t.Errorf("rr[%d].Header().Rrtype = %d; want %d", idx, g, e) + } + if _, ok := rr.(*dnsRR_SRV); !ok { + t.Errorf("answer[%d] = %T; want *dnsRR_SRV", idx, rr) + } + } + _, addrs, err := answer("_xmpp-server._tcp.google.com.", "foo:53", msg, uint16(dnsTypeSRV)) + if err != nil { + t.Fatalf("answer: %v", err) + } + if g, e := len(addrs), 5; g != e { + t.Errorf("len(addrs) = %d; want %d", g, e) + t.Logf("addrs = %#v", addrs) + } +} + +func TestDNSParseCorruptSRVReply(t *testing.T) { + if runtime.GOOS == "windows" { + return + } + data, err := hex.DecodeString(dnsSRVCorruptReply) + if err != nil { + t.Fatal(err) + } + msg := new(dnsMsg) + ok := msg.Unpack(data) + if !ok { + t.Fatalf("unpacking packet failed") + } + if g, e := len(msg.answer), 5; g != e { + t.Errorf("len(msg.answer) = %d; want %d", g, e) + } + for idx, rr := range msg.answer { + if g, e := rr.Header().Rrtype, uint16(dnsTypeSRV); g != e { + t.Errorf("rr[%d].Header().Rrtype = %d; want %d", idx, g, e) + } + if idx == 4 { + if _, ok := rr.(*dnsRR_Header); !ok { + t.Errorf("answer[%d] = %T; want *dnsRR_Header", idx, rr) + } + } else { + if _, ok := rr.(*dnsRR_SRV); !ok { + t.Errorf("answer[%d] = %T; want *dnsRR_SRV", idx, rr) + } + } + } + _, addrs, err := answer("_xmpp-server._tcp.google.com.", "foo:53", msg, uint16(dnsTypeSRV)) + if err != nil { + t.Fatalf("answer: %v", err) + } + if g, e := len(addrs), 4; g != e { + t.Errorf("len(addrs) = %d; want %d", g, e) + t.Logf("addrs = %#v", addrs) + } +} + +// Valid DNS SRV reply +const dnsSRVReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" + + "6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" + + "73657276657234016c06676f6f676c6503636f6d00c00c002100010000012c00210014" + + "000014950c786d70702d73657276657232016c06676f6f676c6503636f6d00c00c0021" + + "00010000012c00210014000014950c786d70702d73657276657233016c06676f6f676c" + + "6503636f6d00c00c002100010000012c00200005000014950b786d70702d7365727665" + + "72016c06676f6f676c6503636f6d00c00c002100010000012c00210014000014950c78" + + "6d70702d73657276657231016c06676f6f676c6503636f6d00" + +// Corrupt DNS SRV reply, with its final RR having a bogus length +// (perhaps it was truncated, or it's malicious) The mutation is the +// capital "FF" below, instead of the proper "21". +const dnsSRVCorruptReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" + + "6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" + + "73657276657234016c06676f6f676c6503636f6d00c00c002100010000012c00210014" + + "000014950c786d70702d73657276657232016c06676f6f676c6503636f6d00c00c0021" + + "00010000012c00210014000014950c786d70702d73657276657233016c06676f6f676c" + + "6503636f6d00c00c002100010000012c00200005000014950b786d70702d7365727665" + + "72016c06676f6f676c6503636f6d00c00c002100010000012c00FF0014000014950c78" + + "6d70702d73657276657231016c06676f6f676c6503636f6d00" diff --git a/libgo/go/net/hosts_test.go b/libgo/go/net/hosts_test.go index 470e35f7863..e5793eef2c7 100644 --- a/libgo/go/net/hosts_test.go +++ b/libgo/go/net/hosts_test.go @@ -5,6 +5,7 @@ package net import ( + "sort" "testing" ) @@ -51,3 +52,17 @@ func TestLookupStaticHost(t *testing.T) { } hostsPath = p } + +func TestLookupHost(t *testing.T) { + // Can't depend on this to return anything in particular, + // but if it does return something, make sure it doesn't + // duplicate addresses (a common bug due to the way + // getaddrinfo works). + addrs, _ := LookupHost("localhost") + sort.SortStrings(addrs) + for i := 0; i+1 < len(addrs); i++ { + if addrs[i] == addrs[i+1] { + t.Fatalf("LookupHost(\"localhost\") = %v, has duplicate addresses", addrs) + } + } +} diff --git a/libgo/go/net/ip.go b/libgo/go/net/ip.go index 12bb6f351a1..61b2c687e2f 100644 --- a/libgo/go/net/ip.go +++ b/libgo/go/net/ip.go @@ -75,7 +75,8 @@ var ( // Well-known IPv6 addresses var ( - IPzero = make(IP, IPv6len) // all zeros + IPzero = make(IP, IPv6len) // all zeros + IPv6loopback = IP([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) ) // Is p all zeros? @@ -436,7 +437,7 @@ func parseIPv6(s string) IP { } // Otherwise must be followed by colon and more. - if s[i] != ':' && i+1 == len(s) { + if s[i] != ':' || i+1 == len(s) { return nil } i++ diff --git a/libgo/go/net/ip_test.go b/libgo/go/net/ip_test.go index f1a4716d227..2008953ef38 100644 --- a/libgo/go/net/ip_test.go +++ b/libgo/go/net/ip_test.go @@ -29,6 +29,7 @@ var parseiptests = []struct { {"127.0.0.1", IPv4(127, 0, 0, 1)}, {"127.0.0.256", nil}, {"abc", nil}, + {"123:", nil}, {"::ffff:127.0.0.1", IPv4(127, 0, 0, 1)}, {"2001:4860:0:2001::68", IP{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, diff --git a/libgo/go/net/ipraw_test.go b/libgo/go/net/ipraw_test.go index 562298bdf46..0c0b675f875 100644 --- a/libgo/go/net/ipraw_test.go +++ b/libgo/go/net/ipraw_test.go @@ -60,7 +60,8 @@ func parsePingReply(p []byte) (id, seq int) { } var srchost = flag.String("srchost", "", "Source of the ICMP ECHO request") -var dsthost = flag.String("dsthost", "localhost", "Destination for the ICMP ECHO request") +// 127.0.0.1 because this is an IPv4-specific test. +var dsthost = flag.String("dsthost", "127.0.0.1", "Destination for the ICMP ECHO request") // test (raw) IP socket using ICMP func TestICMP(t *testing.T) { @@ -69,9 +70,12 @@ func TestICMP(t *testing.T) { return } - var laddr *IPAddr + var ( + laddr *IPAddr + err os.Error + ) if *srchost != "" { - laddr, err := ResolveIPAddr(*srchost) + laddr, err = ResolveIPAddr(*srchost) if err != nil { t.Fatalf(`net.ResolveIPAddr("%v") = %v, %v`, *srchost, laddr, err) } diff --git a/libgo/go/net/iprawsock.go b/libgo/go/net/iprawsock.go index 60433303ae1..5be6fe4e0b9 100644 --- a/libgo/go/net/iprawsock.go +++ b/libgo/go/net/iprawsock.go @@ -245,7 +245,7 @@ func hostToIP(host string) (ip IP, err os.Error) { err = err1 goto Error } - addr = firstSupportedAddr(addrs) + addr = firstSupportedAddr(anyaddr, addrs) if addr == nil { // should not happen err = &AddrError{"LookupHost returned invalid address", addrs[0]} diff --git a/libgo/go/net/ipsock.go b/libgo/go/net/ipsock.go index 80bc3eea5da..e8bcac64603 100644 --- a/libgo/go/net/ipsock.go +++ b/libgo/go/net/ipsock.go @@ -35,15 +35,28 @@ func kernelSupportsIPv6() bool { var preferIPv4 = !kernelSupportsIPv6() -func firstSupportedAddr(addrs []string) (addr IP) { +func firstSupportedAddr(filter func(IP) IP, addrs []string) IP { for _, s := range addrs { - addr = ParseIP(s) - if !preferIPv4 || addr.To4() != nil { - break + if addr := filter(ParseIP(s)); addr != nil { + return addr } - addr = nil } - return addr + return nil +} + +func anyaddr(x IP) IP { return x } +func ipv4only(x IP) IP { return x.To4() } + +func ipv6only(x IP) IP { + // Only return addresses that we can use + // with the kernel's IPv6 addressing modes. + // If preferIPv4 is set, it means the IPv6 stack + // cannot take IPv4 addresses directly (we prefer + // to use the IPv4 stack) so reject IPv4 addresses. + if x.To4() != nil && preferIPv4 { + return nil + } + return x } // TODO(rsc): if syscall.OS == "linux", we're supposd to read @@ -131,7 +144,6 @@ func (e InvalidAddrError) String() string { return string(e) } func (e InvalidAddrError) Timeout() bool { return false } func (e InvalidAddrError) Temporary() bool { return false } - func ipToSockaddr(family int, ip IP, port int) (syscall.Sockaddr, os.Error) { switch family { case syscall.AF_INET: @@ -218,13 +230,31 @@ func hostPortToIP(net, hostport string) (ip IP, iport int, err os.Error) { // Try as an IP address. addr = ParseIP(host) if addr == nil { + filter := anyaddr + if len(net) >= 4 && net[3] == '4' { + filter = ipv4only + } else if len(net) >= 4 && net[3] == '6' { + filter = ipv6only + } // Not an IP address. Try as a DNS name. addrs, err1 := LookupHost(host) if err1 != nil { err = err1 goto Error } - addr = firstSupportedAddr(addrs) + if filter == anyaddr { + // We'll take any IP address, but since the dialing code + // does not yet try multiple addresses, prefer to use + // an IPv4 address if possible. This is especially relevant + // if localhost resolves to [ipv6-localhost, ipv4-localhost]. + // Too much code assumes localhost == ipv4-localhost. + addr = firstSupportedAddr(ipv4only, addrs) + if addr == nil { + addr = firstSupportedAddr(anyaddr, addrs) + } + } else { + addr = firstSupportedAddr(filter, addrs) + } if addr == nil { // should not happen err = &AddrError{"LookupHost returned invalid address", addrs[0]} diff --git a/libgo/go/net/lookup.go b/libgo/go/net/lookup.go index 7b2185ed419..eeb22a8ae3d 100644 --- a/libgo/go/net/lookup.go +++ b/libgo/go/net/lookup.go @@ -36,3 +36,15 @@ func LookupPort(network, service string) (port int, err os.Error) { } return } + +// LookupCNAME returns the canonical DNS host for the given name. +// Callers that do not care about the canonical name can call +// LookupHost or LookupIP directly; both take care of resolving +// the canonical name as part of the lookup. +func LookupCNAME(name string) (cname string, err os.Error) { + cname, err, ok := cgoLookupCNAME(name) + if !ok { + cname, err = goLookupCNAME(name) + } + return +} diff --git a/libgo/go/net/multicast_test.go b/libgo/go/net/multicast_test.go index 32fdec85bde..be6dbf2dc19 100644 --- a/libgo/go/net/multicast_test.go +++ b/libgo/go/net/multicast_test.go @@ -5,14 +5,21 @@ package net import ( + "flag" "runtime" "testing" ) +var multicast = flag.Bool("multicast", false, "enable multicast tests") + func TestMulticastJoinAndLeave(t *testing.T) { if runtime.GOOS == "windows" { return } + if !*multicast { + t.Logf("test disabled; use --multicast to enable") + return + } addr := &UDPAddr{ IP: IPv4zero, @@ -40,6 +47,10 @@ func TestMulticastJoinAndLeave(t *testing.T) { } func TestJoinFailureWithIPv6Address(t *testing.T) { + if !*multicast { + t.Logf("test disabled; use --multicast to enable") + return + } addr := &UDPAddr{ IP: IPv4zero, Port: 0, diff --git a/libgo/go/net/net.go b/libgo/go/net/net.go index 04a898a9aac..51db1073954 100644 --- a/libgo/go/net/net.go +++ b/libgo/go/net/net.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The net package provides a portable interface to Unix -// networks sockets, including TCP/IP, UDP, domain name -// resolution, and Unix domain sockets. +// Package net provides a portable interface to Unix networks sockets, +// including TCP/IP, UDP, domain name resolution, and Unix domain sockets. package net // TODO(rsc): diff --git a/libgo/go/net/resolv_windows.go b/libgo/go/net/resolv_windows.go index 000c3065911..f7c3f51bef1 100644 --- a/libgo/go/net/resolv_windows.go +++ b/libgo/go/net/resolv_windows.go @@ -47,7 +47,7 @@ func goLookupIP(name string) (addrs []IP, err os.Error) { return addrs, nil } -func LookupCNAME(name string) (cname string, err os.Error) { +func goLookupCNAME(name string) (cname string, err os.Error) { var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil) if int(e) != 0 { @@ -113,6 +113,10 @@ func reverseaddr(addr string) (arpa string, err os.Error) { panic("unimplemented") } +func answer(name, server string, dns *dnsMsg, qtype uint16) (cname string, addrs []dnsRR, err os.Error) { + panic("unimplemented") +} + // DNSError represents a DNS lookup error. type DNSError struct { Error string // description of the error diff --git a/libgo/go/net/server_test.go b/libgo/go/net/server_test.go index 37695a068d1..075748b83b0 100644 --- a/libgo/go/net/server_test.go +++ b/libgo/go/net/server_test.go @@ -108,12 +108,10 @@ func doTest(t *testing.T, network, listenaddr, dialaddr string) { } func TestTCPServer(t *testing.T) { - doTest(t, "tcp", "0.0.0.0", "127.0.0.1") - doTest(t, "tcp", "", "127.0.0.1") + doTest(t, "tcp", "127.0.0.1", "127.0.0.1") if kernelSupportsIPv6() { - doTest(t, "tcp", "[::]", "[::ffff:127.0.0.1]") - doTest(t, "tcp", "[::]", "127.0.0.1") - doTest(t, "tcp", "0.0.0.0", "[::ffff:127.0.0.1]") + doTest(t, "tcp", "[::1]", "[::1]") + doTest(t, "tcp", "127.0.0.1", "[::ffff:127.0.0.1]") } } diff --git a/libgo/go/net/sock.go b/libgo/go/net/sock.go index 933700af160..21bd5f03e89 100644 --- a/libgo/go/net/sock.go +++ b/libgo/go/net/sock.go @@ -32,17 +32,7 @@ func socket(net string, f, p, t int, la, ra syscall.Sockaddr, toAddr func(syscal syscall.CloseOnExec(s) syscall.ForkLock.RUnlock() - // Allow reuse of recently-used addresses. - syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) - - // Allow broadcast. - syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) - - if f == syscall.AF_INET6 { - // using ip, tcp, udp, etc. - // allow both protocols even if the OS default is otherwise. - syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) - } + setKernelSpecificSockopt(s, f) if la != nil { e = syscall.Bind(s, la) @@ -161,7 +151,7 @@ type UnknownSocketError struct { } func (e *UnknownSocketError) String() string { - return "unknown socket address type " + reflect.Typeof(e.sa).String() + return "unknown socket address type " + reflect.TypeOf(e.sa).String() } func sockaddrToString(sa syscall.Sockaddr) (name string, err os.Error) { diff --git a/libgo/go/net/sock_bsd.go b/libgo/go/net/sock_bsd.go new file mode 100644 index 00000000000..5fd52074ad3 --- /dev/null +++ b/libgo/go/net/sock_bsd.go @@ -0,0 +1,31 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Sockets for BSD variants + +package net + +import ( + "syscall" +) + +func setKernelSpecificSockopt(s, f int) { + // Allow reuse of recently-used addresses. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + + // Allow reuse of recently-used ports. + // This option is supported only in descendants of 4.4BSD, + // to make an effective multicast application and an application + // that requires quick draw possible. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1) + + // Allow broadcast. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) + + if f == syscall.AF_INET6 { + // using ip, tcp, udp, etc. + // allow both protocols even if the OS default is otherwise. + syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) + } +} diff --git a/libgo/go/net/sock_linux.go b/libgo/go/net/sock_linux.go new file mode 100644 index 00000000000..ec31e803b6f --- /dev/null +++ b/libgo/go/net/sock_linux.go @@ -0,0 +1,25 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Sockets for Linux + +package net + +import ( + "syscall" +) + +func setKernelSpecificSockopt(s, f int) { + // Allow reuse of recently-used addresses. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + + // Allow broadcast. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) + + if f == syscall.AF_INET6 { + // using ip, tcp, udp, etc. + // allow both protocols even if the OS default is otherwise. + syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) + } +} diff --git a/libgo/go/net/sock_windows.go b/libgo/go/net/sock_windows.go new file mode 100644 index 00000000000..e17c60b98b6 --- /dev/null +++ b/libgo/go/net/sock_windows.go @@ -0,0 +1,25 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Sockets for Windows + +package net + +import ( + "syscall" +) + +func setKernelSpecificSockopt(s, f int) { + // Allow reuse of recently-used addresses and ports. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + + // Allow broadcast. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) + + if f == syscall.AF_INET6 { + // using ip, tcp, udp, etc. + // allow both protocols even if the OS default is otherwise. + syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) + } +} diff --git a/libgo/go/net/srv_test.go b/libgo/go/net/srv_test.go index 4dd6089cdd2..f1c7a0ab498 100644 --- a/libgo/go/net/srv_test.go +++ b/libgo/go/net/srv_test.go @@ -8,10 +8,17 @@ package net import ( + "runtime" "testing" ) +var avoidMacFirewall = runtime.GOOS == "darwin" + func TestGoogleSRV(t *testing.T) { + if testing.Short() || avoidMacFirewall { + t.Logf("skipping test to avoid external network") + return + } _, addrs, err := LookupSRV("xmpp-server", "tcp", "google.com") if err != nil { t.Errorf("failed: %s", err) diff --git a/libgo/go/net/tcpsock.go b/libgo/go/net/tcpsock.go index b484be20b46..d9aa7cf19a5 100644 --- a/libgo/go/net/tcpsock.go +++ b/libgo/go/net/tcpsock.go @@ -62,8 +62,8 @@ func (a *TCPAddr) toAddr() sockaddr { // host:port and resolves domain names or port names to // numeric addresses. A literal IPv6 host address must be // enclosed in square brackets, as in "[::]:80". -func ResolveTCPAddr(addr string) (*TCPAddr, os.Error) { - ip, port, err := hostPortToIP("tcp", addr) +func ResolveTCPAddr(network, addr string) (*TCPAddr, os.Error) { + ip, port, err := hostPortToIP(network, addr) if err != nil { return nil, err } diff --git a/libgo/go/net/textproto/textproto.go b/libgo/go/net/textproto/textproto.go index fbfad9d61ce..9f19b5495d1 100644 --- a/libgo/go/net/textproto/textproto.go +++ b/libgo/go/net/textproto/textproto.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The textproto package implements generic support for -// text-based request/response protocols in the style of -// HTTP, NNTP, and SMTP. +// Package textproto implements generic support for text-based request/response +// protocols in the style of HTTP, NNTP, and SMTP. // // The package provides: // diff --git a/libgo/go/net/udpsock.go b/libgo/go/net/udpsock.go index 44d618dab08..67684471b72 100644 --- a/libgo/go/net/udpsock.go +++ b/libgo/go/net/udpsock.go @@ -62,8 +62,8 @@ func (a *UDPAddr) toAddr() sockaddr { // host:port and resolves domain names or port names to // numeric addresses. A literal IPv6 host address must be // enclosed in square brackets, as in "[::]:80". -func ResolveUDPAddr(addr string) (*UDPAddr, os.Error) { - ip, port, err := hostPortToIP("udp", addr) +func ResolveUDPAddr(network, addr string) (*UDPAddr, os.Error) { + ip, port, err := hostPortToIP(network, addr) if err != nil { return nil, err } diff --git a/libgo/go/netchan/common.go b/libgo/go/netchan/common.go index d2cd8efc559..a319391bf16 100644 --- a/libgo/go/netchan/common.go +++ b/libgo/go/netchan/common.go @@ -73,7 +73,7 @@ type unackedCounter interface { // A channel and its direction. type chanDir struct { - ch *reflect.ChanValue + ch reflect.Value dir Dir } diff --git a/libgo/go/netchan/export.go b/libgo/go/netchan/export.go index e91e777e306..1e5ccdb5cba 100644 --- a/libgo/go/netchan/export.go +++ b/libgo/go/netchan/export.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* - The netchan package implements type-safe networked channels: + Package netchan implements type-safe networked channels: it allows the two ends of a channel to appear on different computers connected by a network. It does this by transporting data sent to a channel on one machine so it can be recovered @@ -111,9 +111,9 @@ func (client *expClient) getChan(hdr *header, dir Dir) *netChan { // data arrives from the client. func (client *expClient) run() { hdr := new(header) - hdrValue := reflect.NewValue(hdr) + hdrValue := reflect.ValueOf(hdr) req := new(request) - reqValue := reflect.NewValue(req) + reqValue := reflect.ValueOf(req) error := new(error) for { *hdr = header{} @@ -221,7 +221,7 @@ func (client *expClient) serveSend(hdr header) { return } // Create a new value for each received item. - val := reflect.MakeZero(nch.ch.Type().(*reflect.ChanType).Elem()) + val := reflect.New(nch.ch.Type().Elem()).Elem() if err := client.decode(val); err != nil { expLog("value decode:", err, "; type ", nch.ch.Type()) return @@ -340,26 +340,26 @@ func (exp *Exporter) Sync(timeout int64) os.Error { return exp.clientSet.sync(timeout) } -func checkChan(chT interface{}, dir Dir) (*reflect.ChanValue, os.Error) { - chanType, ok := reflect.Typeof(chT).(*reflect.ChanType) - if !ok { - return nil, os.ErrorString("not a channel") +func checkChan(chT interface{}, dir Dir) (reflect.Value, os.Error) { + chanType := reflect.TypeOf(chT) + if chanType.Kind() != reflect.Chan { + return reflect.Value{}, os.ErrorString("not a channel") } if dir != Send && dir != Recv { - return nil, os.ErrorString("unknown channel direction") + return reflect.Value{}, os.ErrorString("unknown channel direction") } - switch chanType.Dir() { + switch chanType.ChanDir() { case reflect.BothDir: case reflect.SendDir: if dir != Recv { - return nil, os.ErrorString("to import/export with Send, must provide <-chan") + return reflect.Value{}, os.ErrorString("to import/export with Send, must provide <-chan") } case reflect.RecvDir: if dir != Send { - return nil, os.ErrorString("to import/export with Recv, must provide chan<-") + return reflect.Value{}, os.ErrorString("to import/export with Recv, must provide chan<-") } } - return reflect.NewValue(chT).(*reflect.ChanValue), nil + return reflect.ValueOf(chT), nil } // Export exports a channel of a given type and specified direction. The diff --git a/libgo/go/netchan/import.go b/libgo/go/netchan/import.go index 8ba5df9a515..0a700ca2b99 100644 --- a/libgo/go/netchan/import.go +++ b/libgo/go/netchan/import.go @@ -73,10 +73,10 @@ func (imp *Importer) shutdown() { func (imp *Importer) run() { // Loop on responses; requests are sent by ImportNValues() hdr := new(header) - hdrValue := reflect.NewValue(hdr) + hdrValue := reflect.ValueOf(hdr) ackHdr := new(header) err := new(error) - errValue := reflect.NewValue(err) + errValue := reflect.ValueOf(err) for { *hdr = header{} if e := imp.decode(hdrValue); e != nil { @@ -133,7 +133,7 @@ func (imp *Importer) run() { ackHdr.SeqNum = hdr.SeqNum imp.encode(ackHdr, payAck, nil) // Create a new value for each received item. - value := reflect.MakeZero(nch.ch.Type().(*reflect.ChanType).Elem()) + value := reflect.New(nch.ch.Type().Elem()).Elem() if e := imp.decode(value); e != nil { impLog("importer value decode:", e) return diff --git a/libgo/go/os/dir_plan9.go b/libgo/go/os/dir_plan9.go index 7bb0642e479..d9514191d79 100644 --- a/libgo/go/os/dir_plan9.go +++ b/libgo/go/os/dir_plan9.go @@ -8,72 +8,56 @@ import ( "syscall" ) -type dirInfo int - -var markDirectory dirInfo = ^0 - // Readdir reads the contents of the directory associated with file and -// returns an array of up to count FileInfo structures, as would be returned -// by Lstat, in directory order. Subsequent calls on the same file will yield -// further FileInfos. A negative count means to read the entire directory. +// returns an array of up to count FileInfo structures, in directory order. +// Subsequent calls on the same file will yield further FileInfos. +// A negative count means to read until EOF. // Readdir returns the array and an Error, if any. func (file *File) Readdir(count int) (fi []FileInfo, err Error) { // If this file has no dirinfo, create one. if file.dirinfo == nil { - file.dirinfo = &markDirectory + file.dirinfo = new(dirInfo) } - + d := file.dirinfo size := count if size < 0 { size = 100 } - - result := make([]FileInfo, 0, size) - var buf [syscall.STATMAX]byte - - for { - n, e := file.Read(buf[:]) - - if e != nil { + result := make([]FileInfo, 0, size) // Empty with room to grow. + for count != 0 { + // Refill the buffer if necessary + if d.bufp >= d.nbuf { + d.bufp = 0 + var e Error + d.nbuf, e = file.Read(d.buf[:]) + if e != nil && e != EOF { + return nil, &PathError{"readdir", file.name, e} + } if e == EOF { break } - - return []FileInfo{}, &PathError{"readdir", file.name, e} + if d.nbuf < syscall.STATFIXLEN { + return nil, &PathError{"readdir", file.name, Eshortstat} + } } - if n < syscall.STATFIXLEN { - return []FileInfo{}, &PathError{"readdir", file.name, Eshortstat} + // Get a record from buffer + m, _ := gbit16(d.buf[d.bufp:]) + m += 2 + if m < syscall.STATFIXLEN { + return nil, &PathError{"readdir", file.name, Eshortstat} } - - for i := 0; i < n; { - m, _ := gbit16(buf[i:]) - m += 2 - - if m < syscall.STATFIXLEN { - return []FileInfo{}, &PathError{"readdir", file.name, Eshortstat} - } - - d, e := UnmarshalDir(buf[i : i+int(m)]) - - if e != nil { - return []FileInfo{}, &PathError{"readdir", file.name, e} - } - - var f FileInfo - fileInfoFromStat(&f, d) - - result = append(result, f) - - // a negative count means to read until EOF. - if count > 0 && len(result) >= count { - break - } - - i += int(m) + dir, e := UnmarshalDir(d.buf[d.bufp : d.bufp+int(m)]) + if e != nil { + return nil, &PathError{"readdir", file.name, e} } - } + var f FileInfo + fileInfoFromStat(&f, dir) + result = append(result, f) + d.bufp += int(m) + count-- + } return result, nil } @@ -90,7 +74,7 @@ func (file *File) Readdirnames(count int) (names []string, err Error) { names = make([]string, len(fi)) err = nil - for i, _ := range fi { + for i := range fi { names[i] = fi[i].Name } diff --git a/libgo/go/os/env.go b/libgo/go/os/env.go index 3a6d79dd095..3772c090b8f 100644 --- a/libgo/go/os/env.go +++ b/libgo/go/os/env.go @@ -6,6 +6,8 @@ package os +func setenv_c(k, v string) + // Expand replaces ${var} or $var in the string based on the mapping function. // Invocations of undefined variables are replaced with the empty string. func Expand(s string, mapping func(string) string) string { diff --git a/libgo/go/os/env_unix.go b/libgo/go/os/env_unix.go index e7e1c3b90f1..8aa71e83a0c 100644 --- a/libgo/go/os/env_unix.go +++ b/libgo/go/os/env_unix.go @@ -29,6 +29,8 @@ func copyenv() { } } +var envLock sync.RWMutex + // Getenverror retrieves the value of the environment variable named by the key. // It returns the value and an error, if any. func Getenverror(key string) (value string, err Error) { @@ -37,6 +39,10 @@ func Getenverror(key string) (value string, err Error) { if len(key) == 0 { return "", EINVAL } + + envLock.RLock() + defer envLock.RUnlock() + v, ok := env[key] if !ok { return "", ENOENV @@ -55,35 +61,43 @@ func Getenv(key string) string { // It returns an Error, if any. func Setenv(key, value string) Error { once.Do(copyenv) - if len(key) == 0 { return EINVAL } + + envLock.Lock() + defer envLock.Unlock() + env[key] = value + setenv_c(key, value) // is a no-op if cgo isn't loaded return nil } // Clearenv deletes all environment variables. func Clearenv() { once.Do(copyenv) // prevent copyenv in Getenv/Setenv + + envLock.Lock() + defer envLock.Unlock() + env = make(map[string]string) + + // TODO(bradfitz): pass through to C } // Environ returns an array of strings representing the environment, // in the form "key=value". func Environ() []string { once.Do(copyenv) + envLock.RLock() + defer envLock.RUnlock() a := make([]string, len(env)) i := 0 for k, v := range env { - // check i < len(a) for safety, - // in case env is changing underfoot. - if i < len(a) { - a[i] = k + "=" + v - i++ - } + a[i] = k + "=" + v + i++ } - return a[0:i] + return a } // TempDir returns the default directory to use for temporary files. diff --git a/libgo/go/os/error_plan9.go b/libgo/go/os/error_plan9.go index d6575864e84..3374775b8e7 100644 --- a/libgo/go/os/error_plan9.go +++ b/libgo/go/os/error_plan9.go @@ -37,12 +37,15 @@ var ( Enonexist = NewError("file does not exist") Eexist = NewError("file already exists") Eio = NewError("i/o error") + Eperm = NewError("permission denied") EINVAL = Ebadarg ENOTDIR = Enotdir ENOENT = Enonexist EEXIST = Eexist EIO = Eio + EACCES = Eperm + EISDIR = syscall.EISDIR ENAMETOOLONG = NewError("file name too long") ERANGE = NewError("math result not representable") diff --git a/libgo/go/os/file.go b/libgo/go/os/file.go index 3aad8023453..dff8fa862ce 100644 --- a/libgo/go/os/file.go +++ b/libgo/go/os/file.go @@ -2,12 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The os package provides a platform-independent interface to operating -// system functionality. The design is Unix-like. +// Package os provides a platform-independent interface to operating system +// functionality. The design is Unix-like. package os import ( "runtime" + "sync" "syscall" ) @@ -15,8 +16,9 @@ import ( type File struct { fd int name string - dirinfo *dirInfo // nil unless directory being read - nepipe int // number of consecutive EPIPE in Write + dirinfo *dirInfo // nil unless directory being read + nepipe int // number of consecutive EPIPE in Write + l sync.Mutex // used to implement windows pread/pwrite } // Fd returns the integer Unix file descriptor referencing the open file. @@ -30,7 +32,7 @@ func NewFile(fd int, name string) *File { if fd < 0 { return nil } - f := &File{fd, name, nil, 0} + f := &File{fd: fd, name: name} runtime.SetFinalizer(f, (*File).Close) return f } @@ -85,7 +87,7 @@ func (file *File) Read(b []byte) (n int, err Error) { if file == nil { return 0, EINVAL } - n, e := syscall.Read(file.fd, b) + n, e := file.read(b) if n < 0 { n = 0 } @@ -107,7 +109,7 @@ func (file *File) ReadAt(b []byte, off int64) (n int, err Error) { return 0, EINVAL } for len(b) > 0 { - m, e := syscall.Pread(file.fd, b, off) + m, e := file.pread(b, off) if m == 0 && !iserror(e) { return n, EOF } @@ -129,7 +131,7 @@ func (file *File) Write(b []byte) (n int, err Error) { if file == nil { return 0, EINVAL } - n, e := syscall.Write(file.fd, b) + n, e := file.write(b) if n < 0 { n = 0 } @@ -150,7 +152,7 @@ func (file *File) WriteAt(b []byte, off int64) (n int, err Error) { return 0, EINVAL } for len(b) > 0 { - m, e := syscall.Pwrite(file.fd, b, off) + m, e := file.pwrite(b, off) if iserror(e) { err = &PathError{"write", file.name, Errno(e)} break @@ -167,7 +169,7 @@ func (file *File) WriteAt(b []byte, off int64) (n int, err Error) { // relative to the current offset, and 2 means relative to the end. // It returns the new offset and an Error, if any. func (file *File) Seek(offset int64, whence int) (ret int64, err Error) { - r, e := syscall.Seek(file.fd, offset, whence) + r, e := file.seek(offset, whence) if !iserror(e) && file.dirinfo != nil && r != 0 { e = syscall.EISDIR } diff --git a/libgo/go/os/file_plan9.go b/libgo/go/os/file_plan9.go index b79256c51ed..7b473f80221 100644 --- a/libgo/go/os/file_plan9.go +++ b/libgo/go/os/file_plan9.go @@ -9,6 +9,13 @@ import ( "syscall" ) +// Auxiliary information if the File describes a directory +type dirInfo struct { + buf [syscall.STATMAX]byte // buffer for directory I/O + nbuf int // length of buf; return value from Read + bufp int // location of next record in buf. +} + func epipecheck(file *File, e syscall.Error) { } @@ -110,6 +117,39 @@ func (f *File) Sync() (err Error) { return nil } +// read reads up to len(b) bytes from the File. +// It returns the number of bytes read and an error, if any. +func (f *File) read(b []byte) (n int, err syscall.Error) { + return syscall.Read(f.fd, b) +} + +// pread reads len(b) bytes from the File starting at byte offset off. +// It returns the number of bytes read and the error, if any. +// EOF is signaled by a zero count with err set to nil. +func (f *File) pread(b []byte, off int64) (n int, err syscall.Error) { + return syscall.Pread(f.fd, b, off) +} + +// write writes len(b) bytes to the File. +// It returns the number of bytes written and an error, if any. +func (f *File) write(b []byte) (n int, err syscall.Error) { + return syscall.Write(f.fd, b) +} + +// pwrite writes len(b) bytes to the File starting at byte offset off. +// It returns the number of bytes written and an error, if any. +func (f *File) pwrite(b []byte, off int64) (n int, err syscall.Error) { + return syscall.Pwrite(f.fd, b, off) +} + +// seek sets the offset for the next Read or Write on file to offset, interpreted +// according to whence: 0 means relative to the origin of the file, 1 means +// relative to the current offset, and 2 means relative to the end. +// It returns the new offset and an error, if any. +func (f *File) seek(offset int64, whence int) (ret int64, err syscall.Error) { + return syscall.Seek(f.fd, offset, whence) +} + // Truncate changes the size of the named file. // If the file is a symbolic link, it changes the size of the link's target. func Truncate(name string, size int64) Error { diff --git a/libgo/go/os/file_posix.go b/libgo/go/os/file_posix.go index 5151df49873..f1191d61feb 100644 --- a/libgo/go/os/file_posix.go +++ b/libgo/go/os/file_posix.go @@ -10,11 +10,13 @@ import ( "syscall" ) +func sigpipe() // implemented in package runtime + func epipecheck(file *File, e int) { if e == syscall.EPIPE { file.nepipe++ if file.nepipe >= 10 { - Exit(syscall.EPIPE) + sigpipe() } } else { file.nepipe = 0 diff --git a/libgo/go/os/file_unix.go b/libgo/go/os/file_unix.go index f80f1d538ef..4c69cc83802 100644 --- a/libgo/go/os/file_unix.go +++ b/libgo/go/os/file_unix.go @@ -102,6 +102,39 @@ func (file *File) Readdir(count int) (fi []FileInfo, err Error) { return } +// read reads up to len(b) bytes from the File. +// It returns the number of bytes read and an error, if any. +func (f *File) read(b []byte) (n int, err int) { + return syscall.Read(f.fd, b) +} + +// pread reads len(b) bytes from the File starting at byte offset off. +// It returns the number of bytes read and the error, if any. +// EOF is signaled by a zero count with err set to 0. +func (f *File) pread(b []byte, off int64) (n int, err int) { + return syscall.Pread(f.fd, b, off) +} + +// write writes len(b) bytes to the File. +// It returns the number of bytes written and an error, if any. +func (f *File) write(b []byte) (n int, err int) { + return syscall.Write(f.fd, b) +} + +// pwrite writes len(b) bytes to the File starting at byte offset off. +// It returns the number of bytes written and an error, if any. +func (f *File) pwrite(b []byte, off int64) (n int, err int) { + return syscall.Pwrite(f.fd, b, off) +} + +// seek sets the offset for the next Read or Write on file to offset, interpreted +// according to whence: 0 means relative to the origin of the file, 1 means +// relative to the current offset, and 2 means relative to the end. +// It returns the new offset and an error, if any. +func (f *File) seek(offset int64, whence int) (ret int64, err int) { + return syscall.Seek(f.fd, offset, whence) +} + // Truncate changes the size of the named file. // If the file is a symbolic link, it changes the size of the link's target. func Truncate(name string, size int64) Error { diff --git a/libgo/go/os/inotify/inotify_linux.go b/libgo/go/os/inotify/inotify_linux.go index 96c229e7b74..7c7b7698feb 100644 --- a/libgo/go/os/inotify/inotify_linux.go +++ b/libgo/go/os/inotify/inotify_linux.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* -This package implements a wrapper for the Linux inotify system. +Package inotify implements a wrapper for the Linux inotify system. Example: watcher, err := inotify.NewWatcher() @@ -109,7 +109,7 @@ func (w *Watcher) AddWatch(path string, flags uint32) os.Error { } wd, errno := syscall.InotifyAddWatch(w.fd, path, flags) if wd == -1 { - return os.NewSyscallError("inotify_add_watch", errno) + return &os.PathError{"inotify_add_watch", path, os.Errno(errno)} } if !found { diff --git a/libgo/go/os/inotify/inotify_linux_test.go b/libgo/go/os/inotify/inotify_linux_test.go index f5d1f8384df..e29a46d6c2d 100644 --- a/libgo/go/os/inotify/inotify_linux_test.go +++ b/libgo/go/os/inotify/inotify_linux_test.go @@ -17,8 +17,8 @@ func TestInotifyEvents(t *testing.T) { t.Fatalf("NewWatcher() failed: %s", err) } - // Add a watch for "_obj" - err = watcher.Watch("_obj") + // Add a watch for "_test" + err = watcher.Watch("_test") if err != nil { t.Fatalf("Watcher.Watch() failed: %s", err) } @@ -30,7 +30,7 @@ func TestInotifyEvents(t *testing.T) { } }() - const testFile string = "_obj/TestInotifyEvents.testfile" + const testFile string = "_test/TestInotifyEvents.testfile" // Receive events on the event channel on a separate goroutine eventstream := watcher.Event @@ -90,7 +90,7 @@ func TestInotifyClose(t *testing.T) { t.Fatal("double Close() test failed: second Close() call didn't return") } - err := watcher.Watch("_obj") + err := watcher.Watch("_test") if err == nil { t.Fatal("expected error on Watch() after Close(), got nil") } diff --git a/libgo/go/os/os_test.go b/libgo/go/os/os_test.go index 607a11be334..8b8a8434970 100644 --- a/libgo/go/os/os_test.go +++ b/libgo/go/os/os_test.go @@ -43,6 +43,14 @@ var sysdir = func() (sd *sysDir) { "services", }, } + case "plan9": + sd = &sysDir{ + "/lib/ndb", + []string{ + "common", + "local", + }, + } default: sd = &sysDir{ "/etc", @@ -243,8 +251,11 @@ func smallReaddirnames(file *File, length int, t *testing.T) []string { func TestReaddirnamesOneAtATime(t *testing.T) { // big directory that doesn't change often. dir := "/usr/bin" - if syscall.OS == "windows" { + switch syscall.OS { + case "windows": dir = Getenv("SystemRoot") + "\\system32" + case "plan9": + dir = "/bin" } file, err := Open(dir) defer file.Close() @@ -260,6 +271,9 @@ func TestReaddirnamesOneAtATime(t *testing.T) { t.Fatalf("open %q failed: %v", dir, err2) } small := smallReaddirnames(file1, len(all)+100, t) // +100 in case we screw up + if len(small) < len(all) { + t.Fatalf("len(small) is %d, less than %d", len(small), len(all)) + } for i, n := range all { if small[i] != n { t.Errorf("small read %q mismatch: %v", small[i], n) @@ -551,8 +565,8 @@ func checkSize(t *testing.T, f *File, size int64) { } } -func TestTruncate(t *testing.T) { - f := newFile("TestTruncate", t) +func TestFTruncate(t *testing.T) { + f := newFile("TestFTruncate", t) defer Remove(f.Name()) defer f.Close() @@ -569,6 +583,24 @@ func TestTruncate(t *testing.T) { checkSize(t, f, 13+9) // wrote at offset past where hello, world was. } +func TestTruncate(t *testing.T) { + f := newFile("TestTruncate", t) + defer Remove(f.Name()) + defer f.Close() + + checkSize(t, f, 0) + f.Write([]byte("hello, world\n")) + checkSize(t, f, 13) + Truncate(f.Name(), 10) + checkSize(t, f, 10) + Truncate(f.Name(), 1024) + checkSize(t, f, 1024) + Truncate(f.Name(), 0) + checkSize(t, f, 0) + f.Write([]byte("surprise!")) + checkSize(t, f, 13+9) // wrote at offset past where hello, world was. +} + // Use TempDir() to make sure we're on a local file system, // so that timings are not distorted by latency and caching. // On NFS, timings can be off due to caching of meta-data on @@ -870,6 +902,18 @@ func TestAppend(t *testing.T) { if s != "new|append" { t.Fatalf("writeFile: have %q want %q", s, "new|append") } + s = writeFile(t, f, O_CREATE|O_APPEND|O_RDWR, "|append") + if s != "new|append|append" { + t.Fatalf("writeFile: have %q want %q", s, "new|append|append") + } + err := Remove(f) + if err != nil { + t.Fatalf("Remove: %v", err) + } + s = writeFile(t, f, O_CREATE|O_APPEND|O_RDWR, "new&append") + if s != "new&append" { + t.Fatalf("writeFile: have %q want %q", s, "new&append") + } } func TestStatDirWithTrailingSlash(t *testing.T) { diff --git a/libgo/go/os/user/lookup_stubs.go b/libgo/go/os/user/lookup_stubs.go new file mode 100644 index 00000000000..2f08f70fd57 --- /dev/null +++ b/libgo/go/os/user/lookup_stubs.go @@ -0,0 +1,19 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package user + +import ( + "fmt" + "os" + "runtime" +) + +func Lookup(username string) (*User, os.Error) { + return nil, fmt.Errorf("user: Lookup not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +func LookupId(int) (*User, os.Error) { + return nil, fmt.Errorf("user: LookupId not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} diff --git a/libgo/go/os/user/lookup_unix.go b/libgo/go/os/user/lookup_unix.go new file mode 100644 index 00000000000..7060530154a --- /dev/null +++ b/libgo/go/os/user/lookup_unix.go @@ -0,0 +1,89 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package user + +import ( + "fmt" + "os" + "strings" + "syscall" + "unsafe" +) + +/* +#include +#include +#include +#include + +static int mygetpwuid_r(int uid, struct passwd *pwd, + char *buf, size_t buflen, struct passwd **result) { + return getpwuid_r(uid, pwd, buf, buflen, result); +} +*/ + +func libc_getpwnam_r(name *byte, pwd *syscall.Passwd, buf *byte, buflen syscall.Size_t, result **syscall.Passwd) int __asm__ ("getpwnam_r") +func libc_getpwuid_r(uid syscall.Uid_t, pwd *syscall.Passwd, buf *byte, buflen syscall.Size_t, result **syscall.Passwd) int __asm__ ("getpwuid_r") + +// Lookup looks up a user by username. If the user cannot be found, +// the returned error is of type UnknownUserError. +func Lookup(username string) (*User, os.Error) { + return lookup(-1, username, true) +} + +// LookupId looks up a user by userid. If the user cannot be found, +// the returned error is of type UnknownUserIdError. +func LookupId(uid int) (*User, os.Error) { + return lookup(uid, "", false) +} + +func lookup(uid int, username string, lookupByName bool) (*User, os.Error) { + var pwd syscall.Passwd + var result *syscall.Passwd + + // FIXME: Should let buf grow if necessary. + const bufSize = 1024 + buf := make([]byte, bufSize) + if lookupByName { + rv := libc_getpwnam_r(syscall.StringBytePtr(username), + &pwd, + &buf[0], + bufSize, + &result) + if rv != 0 { + return nil, fmt.Errorf("user: lookup username %s: %s", username, os.Errno(syscall.GetErrno())) + } + if result == nil { + return nil, UnknownUserError(username) + } + } else { + rv := libc_getpwuid_r(syscall.Uid_t(uid), + &pwd, + &buf[0], + bufSize, + &result) + if rv != 0 { + return nil, fmt.Errorf("user: lookup userid %d: %s", uid, os.Errno(syscall.GetErrno())) + } + if result == nil { + return nil, UnknownUserIdError(uid) + } + } + u := &User{ + Uid: int(pwd.Pw_uid), + Gid: int(pwd.Pw_gid), + Username: syscall.BytePtrToString((*byte)(unsafe.Pointer(pwd.Pw_name))), + Name: syscall.BytePtrToString((*byte)(unsafe.Pointer(pwd.Pw_gecos))), + HomeDir: syscall.BytePtrToString((*byte)(unsafe.Pointer(pwd.Pw_dir))), + } + // The pw_gecos field isn't quite standardized. Some docs + // say: "It is expected to be a comma separated list of + // personal data where the first item is the full name of the + // user." + if i := strings.Index(u.Name, ","); i >= 0 { + u.Name = u.Name[:i] + } + return u, nil +} diff --git a/libgo/go/os/user/user.go b/libgo/go/os/user/user.go new file mode 100644 index 00000000000..dd009211d76 --- /dev/null +++ b/libgo/go/os/user/user.go @@ -0,0 +1,35 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package user allows user account lookups by name or id. +package user + +import ( + "strconv" +) + +// User represents a user account. +type User struct { + Uid int // user id + Gid int // primary group id + Username string + Name string + HomeDir string +} + +// UnknownUserIdError is returned by LookupId when +// a user cannot be found. +type UnknownUserIdError int + +func (e UnknownUserIdError) String() string { + return "user: unknown userid " + strconv.Itoa(int(e)) +} + +// UnknownUserError is returned by Lookup when +// a user cannot be found. +type UnknownUserError string + +func (e UnknownUserError) String() string { + return "user: unknown user " + string(e) +} diff --git a/libgo/go/os/user/user_test.go b/libgo/go/os/user/user_test.go new file mode 100644 index 00000000000..2c142bf1817 --- /dev/null +++ b/libgo/go/os/user/user_test.go @@ -0,0 +1,61 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package user + +import ( + "os" + "reflect" + "runtime" + "syscall" + "testing" +) + +func skip(t *testing.T) bool { + if runtime.GOARCH == "arm" { + t.Logf("user: cgo not implemented on arm; skipping tests") + return true + } + + if runtime.GOOS == "linux" || runtime.GOOS == "freebsd" || runtime.GOOS == "darwin" { + return false + } + + t.Logf("user: Lookup not implemented on %s; skipping test", runtime.GOOS) + return true +} + +func TestLookup(t *testing.T) { + if skip(t) { + return + } + + // Test LookupId on the current user + uid := syscall.Getuid() + u, err := LookupId(uid) + if err != nil { + t.Fatalf("LookupId: %v", err) + } + if e, g := uid, u.Uid; e != g { + t.Errorf("expected Uid of %d; got %d", e, g) + } + fi, err := os.Stat(u.HomeDir) + if err != nil || !fi.IsDirectory() { + t.Errorf("expected a valid HomeDir; stat(%q): err=%v, IsDirectory=%v", err, fi.IsDirectory()) + } + if u.Username == "" { + t.Fatalf("didn't get a username") + } + + // Test Lookup by username, using the username from LookupId + un, err := Lookup(u.Username) + if err != nil { + t.Fatalf("Lookup: %v", err) + } + if !reflect.DeepEqual(u, un) { + t.Errorf("Lookup by userid vs. name didn't match\n"+ + "LookupId(%d): %#v\n"+ + "Lookup(%q): %#v\n",uid, u, u.Username, un) + } +} diff --git a/libgo/go/path/filepath/path.go b/libgo/go/path/filepath/path.go index de673a72577..541a233066a 100644 --- a/libgo/go/path/filepath/path.go +++ b/libgo/go/path/filepath/path.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The filepath package implements utility routines for manipulating -// filename paths in a way compatible with the target operating -// system-defined file paths. +// Package filepath implements utility routines for manipulating filename paths +// in a way compatible with the target operating system-defined file paths. package filepath import ( diff --git a/libgo/go/path/path.go b/libgo/go/path/path.go index 658eec09387..235384667c6 100644 --- a/libgo/go/path/path.go +++ b/libgo/go/path/path.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The path package implements utility routines for manipulating -// slash-separated filename paths. +// Package path implements utility routines for manipulating slash-separated +// filename paths. package path import ( diff --git a/libgo/go/reflect/all_test.go b/libgo/go/reflect/all_test.go index ac07ce5a368..145fd520f5d 100644 --- a/libgo/go/reflect/all_test.go +++ b/libgo/go/reflect/all_test.go @@ -5,11 +5,13 @@ package reflect_test import ( + "bytes" "container/vector" "fmt" "io" "os" . "reflect" +/* "runtime" */ "testing" "unsafe" ) @@ -35,7 +37,7 @@ func assert(t *testing.T, s, want string) { } } -func typestring(i interface{}) string { return Typeof(i).String() } +func typestring(i interface{}) string { return TypeOf(i).String() } var typeTests = []pair{ {struct{ x int }{}, "int"}, @@ -150,50 +152,50 @@ var typeTests = []pair{ b() }) }{}, - "interface { a(func(func(int) int) func(func(int)) int); b() }", + "interface { reflect_test.a(func(func(int) int) func(func(int)) int); reflect_test.b() }", }, } var valueTests = []pair{ - {(int8)(0), "8"}, - {(int16)(0), "16"}, - {(int32)(0), "32"}, - {(int64)(0), "64"}, - {(uint8)(0), "8"}, - {(uint16)(0), "16"}, - {(uint32)(0), "32"}, - {(uint64)(0), "64"}, - {(float32)(0), "256.25"}, - {(float64)(0), "512.125"}, - {(string)(""), "stringy cheese"}, - {(bool)(false), "true"}, - {(*int8)(nil), "*int8(0)"}, - {(**int8)(nil), "**int8(0)"}, - {[5]int32{}, "[5]int32{0, 0, 0, 0, 0}"}, - {(**integer)(nil), "**reflect_test.integer(0)"}, - {(map[string]int32)(nil), "map[string] int32{}"}, - {(chan<- string)(nil), "chan<- string"}, - {struct { + {new(int8), "8"}, + {new(int16), "16"}, + {new(int32), "32"}, + {new(int64), "64"}, + {new(uint8), "8"}, + {new(uint16), "16"}, + {new(uint32), "32"}, + {new(uint64), "64"}, + {new(float32), "256.25"}, + {new(float64), "512.125"}, + {new(string), "stringy cheese"}, + {new(bool), "true"}, + {new(*int8), "*int8(0)"}, + {new(**int8), "**int8(0)"}, + {new([5]int32), "[5]int32{0, 0, 0, 0, 0}"}, + {new(**integer), "**reflect_test.integer(0)"}, + {new(map[string]int32), "map[string] int32{}"}, + {new(chan<- string), "chan<- string"}, + {new(func(a int8, b int32)), "func(int8, int32)(0)"}, + {new(struct { c chan *int32 d float32 - }{}, + }), "struct { c chan *int32; d float32 }{chan *int32, 0}", }, - {(func(a int8, b int32))(nil), "func(int8, int32)(0)"}, - {struct{ c func(chan *integer, *int8) }{}, + {new(struct{ c func(chan *integer, *int8) }), "struct { c func(chan *reflect_test.integer, *int8) }{func(chan *reflect_test.integer, *int8)(0)}", }, - {struct { + {new(struct { a int8 b int32 - }{}, + }), "struct { a int8; b int32 }{0, 0}", }, - {struct { + {new(struct { a int8 b int8 c int32 - }{}, + }), "struct { a int8; b int8; c int32 }{0, 0, 0}", }, } @@ -207,58 +209,46 @@ func testType(t *testing.T, i int, typ Type, want string) { func TestTypes(t *testing.T) { for i, tt := range typeTests { - testType(t, i, NewValue(tt.i).(*StructValue).Field(0).Type(), tt.s) + testType(t, i, ValueOf(tt.i).Field(0).Type(), tt.s) } } func TestSet(t *testing.T) { for i, tt := range valueTests { - v := NewValue(tt.i) - switch v := v.(type) { - case *IntValue: - switch v.Type().Kind() { - case Int: - v.Set(132) - case Int8: - v.Set(8) - case Int16: - v.Set(16) - case Int32: - v.Set(32) - case Int64: - v.Set(64) - } - case *UintValue: - switch v.Type().Kind() { - case Uint: - v.Set(132) - case Uint8: - v.Set(8) - case Uint16: - v.Set(16) - case Uint32: - v.Set(32) - case Uint64: - v.Set(64) - } - case *FloatValue: - switch v.Type().Kind() { - case Float32: - v.Set(256.25) - case Float64: - v.Set(512.125) - } - case *ComplexValue: - switch v.Type().Kind() { - case Complex64: - v.Set(532.125 + 10i) - case Complex128: - v.Set(564.25 + 1i) - } - case *StringValue: - v.Set("stringy cheese") - case *BoolValue: - v.Set(true) + v := ValueOf(tt.i).Elem() + switch v.Kind() { + case Int: + v.SetInt(132) + case Int8: + v.SetInt(8) + case Int16: + v.SetInt(16) + case Int32: + v.SetInt(32) + case Int64: + v.SetInt(64) + case Uint: + v.SetUint(132) + case Uint8: + v.SetUint(8) + case Uint16: + v.SetUint(16) + case Uint32: + v.SetUint(32) + case Uint64: + v.SetUint(64) + case Float32: + v.SetFloat(256.25) + case Float64: + v.SetFloat(512.125) + case Complex64: + v.SetComplex(532.125 + 10i) + case Complex128: + v.SetComplex(564.25 + 1i) + case String: + v.SetString("stringy cheese") + case Bool: + v.SetBool(true) } s := valueToString(v) if s != tt.s { @@ -269,53 +259,40 @@ func TestSet(t *testing.T) { func TestSetValue(t *testing.T) { for i, tt := range valueTests { - v := NewValue(tt.i) - switch v := v.(type) { - case *IntValue: - switch v.Type().Kind() { - case Int: - v.SetValue(NewValue(int(132))) - case Int8: - v.SetValue(NewValue(int8(8))) - case Int16: - v.SetValue(NewValue(int16(16))) - case Int32: - v.SetValue(NewValue(int32(32))) - case Int64: - v.SetValue(NewValue(int64(64))) - } - case *UintValue: - switch v.Type().Kind() { - case Uint: - v.SetValue(NewValue(uint(132))) - case Uint8: - v.SetValue(NewValue(uint8(8))) - case Uint16: - v.SetValue(NewValue(uint16(16))) - case Uint32: - v.SetValue(NewValue(uint32(32))) - case Uint64: - v.SetValue(NewValue(uint64(64))) - } - case *FloatValue: - switch v.Type().Kind() { - case Float32: - v.SetValue(NewValue(float32(256.25))) - case Float64: - v.SetValue(NewValue(512.125)) - } - case *ComplexValue: - switch v.Type().Kind() { - case Complex64: - v.SetValue(NewValue(complex64(532.125 + 10i))) - case Complex128: - v.SetValue(NewValue(complex128(564.25 + 1i))) - } - - case *StringValue: - v.SetValue(NewValue("stringy cheese")) - case *BoolValue: - v.SetValue(NewValue(true)) + v := ValueOf(tt.i).Elem() + switch v.Kind() { + case Int: + v.Set(ValueOf(int(132))) + case Int8: + v.Set(ValueOf(int8(8))) + case Int16: + v.Set(ValueOf(int16(16))) + case Int32: + v.Set(ValueOf(int32(32))) + case Int64: + v.Set(ValueOf(int64(64))) + case Uint: + v.Set(ValueOf(uint(132))) + case Uint8: + v.Set(ValueOf(uint8(8))) + case Uint16: + v.Set(ValueOf(uint16(16))) + case Uint32: + v.Set(ValueOf(uint32(32))) + case Uint64: + v.Set(ValueOf(uint64(64))) + case Float32: + v.Set(ValueOf(float32(256.25))) + case Float64: + v.Set(ValueOf(512.125)) + case Complex64: + v.Set(ValueOf(complex64(532.125 + 10i))) + case Complex128: + v.Set(ValueOf(complex128(564.25 + 1i))) + case String: + v.Set(ValueOf("stringy cheese")) + case Bool: + v.Set(ValueOf(true)) } s := valueToString(v) if s != tt.s { @@ -341,7 +318,7 @@ var valueToStringTests = []pair{ func TestValueToString(t *testing.T) { for i, test := range valueToStringTests { - s := valueToString(NewValue(test.i)) + s := valueToString(ValueOf(test.i)) if s != test.s { t.Errorf("#%d: have %#q, want %#q", i, s, test.s) } @@ -349,16 +326,16 @@ func TestValueToString(t *testing.T) { } func TestArrayElemSet(t *testing.T) { - v := NewValue([10]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) - v.(*ArrayValue).Elem(4).(*IntValue).Set(123) + v := ValueOf(&[10]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}).Elem() + v.Index(4).SetInt(123) s := valueToString(v) const want = "[10]int{1, 2, 3, 4, 123, 6, 7, 8, 9, 10}" if s != want { t.Errorf("[10]int: have %#q want %#q", s, want) } - v = NewValue([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) - v.(*SliceValue).Elem(4).(*IntValue).Set(123) + v = ValueOf([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + v.Index(4).SetInt(123) s = valueToString(v) const want1 = "[]int{1, 2, 3, 4, 123, 6, 7, 8, 9, 10}" if s != want1 { @@ -369,16 +346,16 @@ func TestArrayElemSet(t *testing.T) { func TestPtrPointTo(t *testing.T) { var ip *int32 var i int32 = 1234 - vip := NewValue(&ip) - vi := NewValue(i) - vip.(*PtrValue).Elem().(*PtrValue).PointTo(vi) + vip := ValueOf(&ip) + vi := ValueOf(&i).Elem() + vip.Elem().Set(vi.Addr()) if *ip != 1234 { t.Errorf("got %d, want 1234", *ip) } ip = nil - vp := NewValue(ip).(*PtrValue) - vp.PointTo(vp.Elem()) + vp := ValueOf(&ip).Elem() + vp.Set(Zero(vp.Type())) if ip != nil { t.Errorf("got non-nil (%p), want nil", ip) } @@ -387,8 +364,8 @@ func TestPtrPointTo(t *testing.T) { func TestPtrSetNil(t *testing.T) { var i int32 = 1234 ip := &i - vip := NewValue(&ip) - vip.(*PtrValue).Elem().(*PtrValue).Set(nil) + vip := ValueOf(&ip) + vip.Elem().Set(Zero(vip.Elem().Type())) if ip != nil { t.Errorf("got non-nil (%d), want nil", *ip) } @@ -396,8 +373,8 @@ func TestPtrSetNil(t *testing.T) { func TestMapSetNil(t *testing.T) { m := make(map[string]int) - vm := NewValue(&m) - vm.(*PtrValue).Elem().(*MapValue).Set(nil) + vm := ValueOf(&m) + vm.Elem().Set(Zero(vm.Elem().Type())) if m != nil { t.Errorf("got non-nil (%p), want nil", m) } @@ -405,17 +382,17 @@ func TestMapSetNil(t *testing.T) { func TestAll(t *testing.T) { - testType(t, 1, Typeof((int8)(0)), "int8") - testType(t, 2, Typeof((*int8)(nil)).(*PtrType).Elem(), "int8") + testType(t, 1, TypeOf((int8)(0)), "int8") + testType(t, 2, TypeOf((*int8)(nil)).Elem(), "int8") - typ := Typeof((*struct { + typ := TypeOf((*struct { c chan *int32 d float32 })(nil)) testType(t, 3, typ, "*struct { c chan *int32; d float32 }") - etyp := typ.(*PtrType).Elem() + etyp := typ.Elem() testType(t, 4, etyp, "struct { c chan *int32; d float32 }") - styp := etyp.(*StructType) + styp := etyp f := styp.Field(0) testType(t, 5, f.Type, "chan *int32") @@ -430,60 +407,61 @@ func TestAll(t *testing.T) { t.Errorf("FieldByName says absent field is present") } - typ = Typeof([32]int32{}) + typ = TypeOf([32]int32{}) testType(t, 7, typ, "[32]int32") - testType(t, 8, typ.(*ArrayType).Elem(), "int32") + testType(t, 8, typ.Elem(), "int32") - typ = Typeof((map[string]*int32)(nil)) + typ = TypeOf((map[string]*int32)(nil)) testType(t, 9, typ, "map[string] *int32") - mtyp := typ.(*MapType) + mtyp := typ testType(t, 10, mtyp.Key(), "string") testType(t, 11, mtyp.Elem(), "*int32") - typ = Typeof((chan<- string)(nil)) + typ = TypeOf((chan<- string)(nil)) testType(t, 12, typ, "chan<- string") - testType(t, 13, typ.(*ChanType).Elem(), "string") + testType(t, 13, typ.Elem(), "string") // make sure tag strings are not part of element type - typ = Typeof(struct { + typ = TypeOf(struct { d []uint32 "TAG" - }{}).(*StructType).Field(0).Type + }{}).Field(0).Type testType(t, 14, typ, "[]uint32") } func TestInterfaceGet(t *testing.T) { var inter struct { - e interface{} + E interface{} } - inter.e = 123.456 - v1 := NewValue(&inter) - v2 := v1.(*PtrValue).Elem().(*StructValue).Field(0) + inter.E = 123.456 + v1 := ValueOf(&inter) + v2 := v1.Elem().Field(0) assert(t, v2.Type().String(), "interface { }") - i2 := v2.(*InterfaceValue).Interface() - v3 := NewValue(i2) + i2 := v2.Interface() + v3 := ValueOf(i2) assert(t, v3.Type().String(), "float64") } func TestInterfaceValue(t *testing.T) { var inter struct { - e interface{} + E interface{} } - inter.e = 123.456 - v1 := NewValue(&inter) - v2 := v1.(*PtrValue).Elem().(*StructValue).Field(0) + inter.E = 123.456 + v1 := ValueOf(&inter) + v2 := v1.Elem().Field(0) assert(t, v2.Type().String(), "interface { }") - v3 := v2.(*InterfaceValue).Elem() + v3 := v2.Elem() assert(t, v3.Type().String(), "float64") i3 := v2.Interface() if _, ok := i3.(float64); !ok { - t.Error("v2.Interface() did not return float64, got ", Typeof(i3)) + t.Error("v2.Interface() did not return float64, got ", TypeOf(i3)) } } func TestFunctionValue(t *testing.T) { - v := NewValue(func() {}) - if v.Interface() != v.Interface() { + var x interface{} = func() {} + v := ValueOf(x) + if v.Interface() != v.Interface() || v.Interface() != x { t.Fatalf("TestFunction != itself") } assert(t, v.Type().String(), "func()") @@ -496,6 +474,18 @@ var appendTests = []struct { {make([]int, 2, 4), []int{22, 33, 44}}, } +func sameInts(x, y []int) bool { + if len(x) != len(y) { + return false + } + for i, xx := range x { + if xx != y[i] { + return false + } + } + return true +} + func TestAppend(t *testing.T) { for i, test := range appendTests { origLen, extraLen := len(test.orig), len(test.extra) @@ -503,15 +493,15 @@ func TestAppend(t *testing.T) { // Convert extra from []int to []Value. e0 := make([]Value, len(test.extra)) for j, e := range test.extra { - e0[j] = NewValue(e) + e0[j] = ValueOf(e) } // Convert extra from []int to *SliceValue. - e1 := NewValue(test.extra).(*SliceValue) + e1 := ValueOf(test.extra) // Test Append. - a0 := NewValue(test.orig).(*SliceValue) + a0 := ValueOf(test.orig) have0 := Append(a0, e0...).Interface().([]int) - if !DeepEqual(have0, want) { - t.Errorf("Append #%d: have %v, want %v", i, have0, want) + if !sameInts(have0, want) { + t.Errorf("Append #%d: have %v, want %v (%p %p)", i, have0, want, test.orig, have0) } // Check that the orig and extra slices were not modified. if len(test.orig) != origLen { @@ -521,9 +511,9 @@ func TestAppend(t *testing.T) { t.Errorf("Append #%d extraLen: have %v, want %v", i, len(test.extra), extraLen) } // Test AppendSlice. - a1 := NewValue(test.orig).(*SliceValue) + a1 := ValueOf(test.orig) have1 := AppendSlice(a1, e1).Interface().([]int) - if !DeepEqual(have1, want) { + if !sameInts(have1, want) { t.Errorf("AppendSlice #%d: have %v, want %v", i, have1, want) } // Check that the orig and extra slices were not modified. @@ -545,8 +535,10 @@ func TestCopy(t *testing.T) { t.Fatalf("b != c before test") } } - aa := NewValue(a).(*SliceValue) - ab := NewValue(b).(*SliceValue) + a1 := a + b1 := b + aa := ValueOf(&a1).Elem() + ab := ValueOf(&b1).Elem() for tocopy := 1; tocopy <= 7; tocopy++ { aa.SetLen(tocopy) Copy(ab, aa) @@ -573,14 +565,35 @@ func TestCopy(t *testing.T) { } } +func TestCopyArray(t *testing.T) { + a := [8]int{1, 2, 3, 4, 10, 9, 8, 7} + b := [11]int{11, 22, 33, 44, 1010, 99, 88, 77, 66, 55, 44} + c := b + aa := ValueOf(&a).Elem() + ab := ValueOf(&b).Elem() + Copy(ab, aa) + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + t.Errorf("(i) a[%d]=%d, b[%d]=%d", i, a[i], i, b[i]) + } + } + for i := len(a); i < len(b); i++ { + if b[i] != c[i] { + t.Errorf("(ii) b[%d]=%d, c[%d]=%d", i, b[i], i, c[i]) + } else { + t.Logf("elem %d is okay\n", i) + } + } +} + func TestBigUnnamedStruct(t *testing.T) { b := struct{ a, b, c, d int64 }{1, 2, 3, 4} - v := NewValue(b) + v := ValueOf(b) b1 := v.Interface().(struct { a, b, c, d int64 }) if b1.a != b.a || b1.b != b.b || b1.c != b.c || b1.d != b.d { - t.Errorf("NewValue(%v).Interface().(*Big) = %v", b, b1) + t.Errorf("ValueOf(%v).Interface().(*Big) = %v", b, b1) } } @@ -590,10 +603,10 @@ type big struct { func TestBigStruct(t *testing.T) { b := big{1, 2, 3, 4, 5} - v := NewValue(b) + v := ValueOf(b) b1 := v.Interface().(big) if b1.a != b.a || b1.b != b.b || b1.c != b.c || b1.d != b.d || b1.e != b.e { - t.Errorf("NewValue(%v).Interface().(big) = %v", b, b1) + t.Errorf("ValueOf(%v).Interface().(big) = %v", b, b1) } } @@ -657,15 +670,15 @@ func TestDeepEqual(t *testing.T) { } } -func TestTypeof(t *testing.T) { +func TestTypeOf(t *testing.T) { for _, test := range deepEqualTests { - v := NewValue(test.a) - if v == nil { + v := ValueOf(test.a) + if !v.IsValid() { continue } - typ := Typeof(test.a) + typ := TypeOf(test.a) if typ != v.Type() { - t.Errorf("Typeof(%v) = %v, but NewValue(%v).Type() = %v", test.a, typ, test.a, v.Type()) + t.Errorf("TypeOf(%v) = %v, but ValueOf(%v).Type() = %v", test.a, typ, test.a, v.Type()) } } } @@ -713,10 +726,28 @@ func TestDeepEqualComplexStructInequality(t *testing.T) { } } +type UnexpT struct { + m map[int]int +} + +func TestDeepEqualUnexportedMap(t *testing.T) { + // Check that DeepEqual can look at unexported fields. + x1 := UnexpT{map[int]int{1: 2}} + x2 := UnexpT{map[int]int{1: 2}} + if !DeepEqual(&x1, &x2) { + t.Error("DeepEqual(x1, x2) = false, want true") + } + + y1 := UnexpT{map[int]int{2: 3}} + if DeepEqual(&x1, &y1) { + t.Error("DeepEqual(x1, y1) = true, want false") + } +} + func check2ndField(x interface{}, offs uintptr, t *testing.T) { - s := NewValue(x).(*StructValue) - f := s.Type().(*StructType).Field(1) + s := ValueOf(x) + f := s.Type().Field(1) if f.Offset != offs { t.Error("mismatched offsets in structure alignment:", f.Offset, offs) } @@ -747,36 +778,22 @@ func TestAlignment(t *testing.T) { check2ndField(x1, uintptr(unsafe.Pointer(&x1.f))-uintptr(unsafe.Pointer(&x1)), t) } -type IsNiller interface { - IsNil() bool -} - func Nil(a interface{}, t *testing.T) { - n := NewValue(a).(*StructValue).Field(0).(IsNiller) + n := ValueOf(a).Field(0) if !n.IsNil() { t.Errorf("%v should be nil", a) } } func NotNil(a interface{}, t *testing.T) { - n := NewValue(a).(*StructValue).Field(0).(IsNiller) + n := ValueOf(a).Field(0) if n.IsNil() { - t.Errorf("value of type %v should not be nil", NewValue(a).Type().String()) + t.Errorf("value of type %v should not be nil", ValueOf(a).Type().String()) } } func TestIsNil(t *testing.T) { - // These do not implement IsNil - doNotNil := []interface{}{int(0), float32(0), struct{ a int }{}} - for _, ts := range doNotNil { - ty := Typeof(ts) - v := MakeZero(ty) - if _, ok := v.(IsNiller); ok { - t.Errorf("%s is nilable; should not be", ts) - } - } - - // These do implement IsNil. + // These implement IsNil. // Wrap in extra struct to hide interface type. doNil := []interface{}{ struct{ x *int }{}, @@ -787,11 +804,9 @@ func TestIsNil(t *testing.T) { struct{ x []string }{}, } for _, ts := range doNil { - ty := Typeof(ts).(*StructType).Field(0).Type - v := MakeZero(ty) - if _, ok := v.(IsNiller); !ok { - t.Errorf("%s %T is not nilable; should be", ts, v) - } + ty := TypeOf(ts).Field(0).Type + v := Zero(ty) + v.IsNil() // panics if not okay to call } // Check the implementations @@ -844,56 +859,28 @@ func TestInterfaceExtraction(t *testing.T) { } s.w = os.Stdout - v := Indirect(NewValue(&s)).(*StructValue).Field(0).Interface() + v := Indirect(ValueOf(&s)).Field(0).Interface() if v != s.w.(interface{}) { t.Error("Interface() on interface: ", v, s.w) } } -func TestInterfaceEditing(t *testing.T) { - // strings are bigger than one word, - // so the interface conversion allocates - // memory to hold a string and puts that - // pointer in the interface. - var i interface{} = "hello" - - // if i pass the interface value by value - // to NewValue, i should get a fresh copy - // of the value. - v := NewValue(i) - - // and setting that copy to "bye" should - // not change the value stored in i. - v.(*StringValue).Set("bye") - if i.(string) != "hello" { - t.Errorf(`Set("bye") changed i to %s`, i.(string)) - } - - // the same should be true of smaller items. - i = 123 - v = NewValue(i) - v.(*IntValue).Set(234) - if i.(int) != 123 { - t.Errorf("Set(234) changed i to %d", i.(int)) - } -} - func TestNilPtrValueSub(t *testing.T) { var pi *int - if pv := NewValue(pi).(*PtrValue); pv.Elem() != nil { - t.Error("NewValue((*int)(nil)).(*PtrValue).Elem() != nil") + if pv := ValueOf(pi); pv.Elem().IsValid() { + t.Error("ValueOf((*int)(nil)).Elem().IsValid()") } } func TestMap(t *testing.T) { m := map[string]int{"a": 1, "b": 2} - mv := NewValue(m).(*MapValue) + mv := ValueOf(m) if n := mv.Len(); n != len(m) { t.Errorf("Len = %d, want %d", n, len(m)) } - keys := mv.Keys() + keys := mv.MapKeys() i := 0 - newmap := MakeMap(mv.Type().(*MapType)) + newmap := MakeMap(mv.Type()) for k, v := range m { // Check that returned Keys match keys in range. // These aren't required to be in the same order, @@ -901,22 +888,22 @@ func TestMap(t *testing.T) { // the test easier. if i >= len(keys) { t.Errorf("Missing key #%d %q", i, k) - } else if kv := keys[i].(*StringValue); kv.Get() != k { - t.Errorf("Keys[%d] = %q, want %q", i, kv.Get(), k) + } else if kv := keys[i]; kv.String() != k { + t.Errorf("Keys[%q] = %d, want %d", i, kv.Int(), k) } i++ // Check that value lookup is correct. - vv := mv.Elem(NewValue(k)) - if vi := vv.(*IntValue).Get(); vi != int64(v) { + vv := mv.MapIndex(ValueOf(k)) + if vi := vv.Int(); vi != int64(v) { t.Errorf("Key %q: have value %d, want %d", k, vi, v) } // Copy into new map. - newmap.SetElem(NewValue(k), NewValue(v)) + newmap.SetMapIndex(ValueOf(k), ValueOf(v)) } - vv := mv.Elem(NewValue("not-present")) - if vv != nil { + vv := mv.MapIndex(ValueOf("not-present")) + if vv.IsValid() { t.Errorf("Invalid key: got non-nil value %s", valueToString(vv)) } @@ -932,14 +919,14 @@ func TestMap(t *testing.T) { } } - newmap.SetElem(NewValue("a"), nil) + newmap.SetMapIndex(ValueOf("a"), Value{}) v, ok := newm["a"] if ok { t.Errorf("newm[\"a\"] = %d after delete", v) } - mv = NewValue(&m).(*PtrValue).Elem().(*MapValue) - mv.Set(nil) + mv = ValueOf(&m).Elem() + mv.Set(Zero(mv.Type())) if m != nil { t.Errorf("mv.Set(nil) failed") } @@ -948,55 +935,55 @@ func TestMap(t *testing.T) { func TestChan(t *testing.T) { for loop := 0; loop < 2; loop++ { var c chan int - var cv *ChanValue + var cv Value // check both ways to allocate channels switch loop { case 1: c = make(chan int, 1) - cv = NewValue(c).(*ChanValue) + cv = ValueOf(c) case 0: - cv = MakeChan(Typeof(c).(*ChanType), 1) + cv = MakeChan(TypeOf(c), 1) c = cv.Interface().(chan int) } // Send - cv.Send(NewValue(2)) + cv.Send(ValueOf(2)) if i := <-c; i != 2 { t.Errorf("reflect Send 2, native recv %d", i) } // Recv c <- 3 - if i, ok := cv.Recv(); i.(*IntValue).Get() != 3 || !ok { - t.Errorf("native send 3, reflect Recv %d, %t", i.(*IntValue).Get(), ok) + if i, ok := cv.Recv(); i.Int() != 3 || !ok { + t.Errorf("native send 3, reflect Recv %d, %t", i.Int(), ok) } // TryRecv fail val, ok := cv.TryRecv() - if val != nil || ok { + if val.IsValid() || ok { t.Errorf("TryRecv on empty chan: %s, %t", valueToString(val), ok) } // TryRecv success c <- 4 val, ok = cv.TryRecv() - if val == nil { + if !val.IsValid() { t.Errorf("TryRecv on ready chan got nil") - } else if i := val.(*IntValue).Get(); i != 4 || !ok { + } else if i := val.Int(); i != 4 || !ok { t.Errorf("native send 4, TryRecv %d, %t", i, ok) } // TrySend fail c <- 100 - ok = cv.TrySend(NewValue(5)) + ok = cv.TrySend(ValueOf(5)) i := <-c if ok { t.Errorf("TrySend on full chan succeeded: value %d", i) } // TrySend success - ok = cv.TrySend(NewValue(6)) + ok = cv.TrySend(ValueOf(6)) if !ok { t.Errorf("TrySend on empty chan failed") } else { @@ -1008,27 +995,27 @@ func TestChan(t *testing.T) { // Close c <- 123 cv.Close() - if i, ok := cv.Recv(); i.(*IntValue).Get() != 123 || !ok { - t.Errorf("send 123 then close; Recv %d, %t", i.(*IntValue).Get(), ok) + if i, ok := cv.Recv(); i.Int() != 123 || !ok { + t.Errorf("send 123 then close; Recv %d, %t", i.Int(), ok) } - if i, ok := cv.Recv(); i.(*IntValue).Get() != 0 || ok { - t.Errorf("after close Recv %d, %t", i.(*IntValue).Get(), ok) + if i, ok := cv.Recv(); i.Int() != 0 || ok { + t.Errorf("after close Recv %d, %t", i.Int(), ok) } } // check creation of unbuffered channel var c chan int - cv := MakeChan(Typeof(c).(*ChanType), 0) + cv := MakeChan(TypeOf(c), 0) c = cv.Interface().(chan int) - if cv.TrySend(NewValue(7)) { + if cv.TrySend(ValueOf(7)) { t.Errorf("TrySend on sync chan succeeded") } - if v, ok := cv.TryRecv(); v != nil || ok { - t.Errorf("TryRecv on sync chan succeeded") + if v, ok := cv.TryRecv(); v.IsValid() || ok { + t.Errorf("TryRecv on sync chan succeeded: isvalid=%v ok=%v", v.IsValid(), ok) } // len/cap - cv = MakeChan(Typeof(c).(*ChanType), 10) + cv = MakeChan(TypeOf(c), 10) c = cv.Interface().(chan int) for i := 0; i < 3; i++ { c <- i @@ -1046,14 +1033,14 @@ func dummy(b byte, c int, d byte) (i byte, j int, k byte) { } func TestFunc(t *testing.T) { - ret := NewValue(dummy).(*FuncValue).Call([]Value{NewValue(byte(10)), NewValue(20), NewValue(byte(30))}) + ret := ValueOf(dummy).Call([]Value{ValueOf(byte(10)), ValueOf(20), ValueOf(byte(30))}) if len(ret) != 3 { t.Fatalf("Call returned %d values, want 3", len(ret)) } - i := ret[0].(*UintValue).Get() - j := ret[1].(*IntValue).Get() - k := ret[2].(*UintValue).Get() + i := byte(ret[0].Uint()) + j := int(ret[1].Int()) + k := byte(ret[2].Uint()) if i != 10 || j != 20 || k != 30 { t.Errorf("Call returned %d, %d, %d; want 10, 20, 30", i, j, k) } @@ -1063,50 +1050,47 @@ type Point struct { x, y int } -func (p Point) Dist(scale int) int { return p.x*p.x*scale + p.y*p.y*scale } +func (p Point) Dist(scale int) int { + // println("Point.Dist", p.x, p.y, scale) + return p.x*p.x*scale + p.y*p.y*scale +} func TestMethod(t *testing.T) { // Non-curried method of type. p := Point{3, 4} - i := Typeof(p).Method(0).Func.Call([]Value{NewValue(p), NewValue(10)})[0].(*IntValue).Get() + i := TypeOf(p).Method(0).Func.Call([]Value{ValueOf(p), ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Type Method returned %d; want 250", i) } - i = Typeof(&p).Method(0).Func.Call([]Value{NewValue(&p), NewValue(10)})[0].(*IntValue).Get() + i = TypeOf(&p).Method(0).Func.Call([]Value{ValueOf(&p), ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Pointer Type Method returned %d; want 250", i) } // Curried method of value. - i = NewValue(p).Method(0).Call([]Value{NewValue(10)})[0].(*IntValue).Get() + i = ValueOf(p).Method(0).Call([]Value{ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Value Method returned %d; want 250", i) } // Curried method of pointer. - i = NewValue(&p).Method(0).Call([]Value{NewValue(10)})[0].(*IntValue).Get() - if i != 250 { - t.Errorf("Value Method returned %d; want 250", i) - } - - // Curried method of pointer to value. - i = NewValue(p).Addr().Method(0).Call([]Value{NewValue(10)})[0].(*IntValue).Get() + i = ValueOf(&p).Method(0).Call([]Value{ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Value Method returned %d; want 250", i) } // Curried method of interface value. // Have to wrap interface value in a struct to get at it. - // Passing it to NewValue directly would + // Passing it to ValueOf directly would // access the underlying Point, not the interface. var s = struct { - x interface { + X interface { Dist(int) int } }{p} - pv := NewValue(s).(*StructValue).Field(0) - i = pv.Method(0).Call([]Value{NewValue(10)})[0].(*IntValue).Get() + pv := ValueOf(s).Field(0) + i = pv.Method(0).Call([]Value{ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Interface Method returned %d; want 250", i) } @@ -1121,19 +1105,19 @@ func TestInterfaceSet(t *testing.T) { Dist(int) int } } - sv := NewValue(&s).(*PtrValue).Elem().(*StructValue) - sv.Field(0).(*InterfaceValue).Set(NewValue(p)) + sv := ValueOf(&s).Elem() + sv.Field(0).Set(ValueOf(p)) if q := s.I.(*Point); q != p { t.Errorf("i: have %p want %p", q, p) } - pv := sv.Field(1).(*InterfaceValue) - pv.Set(NewValue(p)) + pv := sv.Field(1) + pv.Set(ValueOf(p)) if q := s.P.(*Point); q != p { t.Errorf("i: have %p want %p", q, p) } - i := pv.Method(0).Call([]Value{NewValue(10)})[0].(*IntValue).Get() + i := pv.Method(0).Call([]Value{ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Interface Method returned %d; want 250", i) } @@ -1148,7 +1132,7 @@ func TestAnonymousFields(t *testing.T) { var field StructField var ok bool var t1 T1 - type1 := Typeof(t1).(*StructType) + type1 := TypeOf(t1) if field, ok = type1.FieldByName("int"); !ok { t.Error("no field 'int'") } @@ -1232,7 +1216,7 @@ var fieldTests = []FTest{ func TestFieldByIndex(t *testing.T) { for _, test := range fieldTests { - s := Typeof(test.s).(*StructType) + s := TypeOf(test.s) f := s.FieldByIndex(test.index) if f.Name != "" { if test.index != nil { @@ -1247,8 +1231,8 @@ func TestFieldByIndex(t *testing.T) { } if test.value != 0 { - v := NewValue(test.s).(*StructValue).FieldByIndex(test.index) - if v != nil { + v := ValueOf(test.s).FieldByIndex(test.index) + if v.IsValid() { if x, ok := v.Interface().(int); ok { if x != test.value { t.Errorf("%s%v is %d; want %d", s.Name(), test.index, x, test.value) @@ -1265,7 +1249,7 @@ func TestFieldByIndex(t *testing.T) { func TestFieldByName(t *testing.T) { for _, test := range fieldTests { - s := Typeof(test.s).(*StructType) + s := TypeOf(test.s) f, found := s.FieldByName(test.name) if found { if test.index != nil { @@ -1287,8 +1271,8 @@ func TestFieldByName(t *testing.T) { } if test.value != 0 { - v := NewValue(test.s).(*StructValue).FieldByName(test.name) - if v != nil { + v := ValueOf(test.s).FieldByName(test.name) + if v.IsValid() { if x, ok := v.Interface().(int); ok { if x != test.value { t.Errorf("%s.%s is %d; want %d", s.Name(), test.name, x, test.value) @@ -1304,19 +1288,19 @@ func TestFieldByName(t *testing.T) { } func TestImportPath(t *testing.T) { - if path := Typeof(vector.Vector{}).PkgPath(); path != "libgo_container.vector" { - t.Errorf("Typeof(vector.Vector{}).PkgPath() = %q, want \"libgo_container.vector\"", path) + if path := TypeOf(vector.Vector{}).PkgPath(); path != "libgo_container.vector" { + t.Errorf("TypeOf(vector.Vector{}).PkgPath() = %q, want \"libgo_container.vector\"", path) } } func TestDotDotDot(t *testing.T) { // Test example from FuncType.DotDotDot documentation. var f func(x int, y ...float64) - typ := Typeof(f).(*FuncType) - if typ.NumIn() == 2 && typ.In(0) == Typeof(int(0)) { - sl, ok := typ.In(1).(*SliceType) - if ok { - if sl.Elem() == Typeof(0.0) { + typ := TypeOf(f) + if typ.NumIn() == 2 && typ.In(0) == TypeOf(int(0)) { + sl := typ.In(1) + if sl.Kind() == Slice { + if sl.Elem() == TypeOf(0.0) { // ok return } @@ -1345,50 +1329,50 @@ func (*inner) m() {} func (*outer) m() {} func TestNestedMethods(t *testing.T) { - typ := Typeof((*outer)(nil)) - if typ.NumMethod() != 1 || typ.Method(0).Func.Get() != NewValue((*outer).m).(*FuncValue).Get() { + typ := TypeOf((*outer)(nil)) + if typ.NumMethod() != 1 || typ.Method(0).Func.Pointer() != ValueOf((*outer).m).Pointer() { t.Errorf("Wrong method table for outer: (m=%p)", (*outer).m) for i := 0; i < typ.NumMethod(); i++ { m := typ.Method(i) - t.Errorf("\t%d: %s %#x\n", i, m.Name, m.Func.Get()) + t.Errorf("\t%d: %s %#x\n", i, m.Name, m.Func.Pointer()) } } } -type innerInt struct { - x int +type InnerInt struct { + X int } -type outerInt struct { - y int - innerInt +type OuterInt struct { + Y int + InnerInt } -func (i *innerInt) m() int { - return i.x +func (i *InnerInt) M() int { + return i.X } func TestEmbeddedMethods(t *testing.T) { - typ := Typeof((*outerInt)(nil)) - if typ.NumMethod() != 1 || typ.Method(0).Func.Get() != NewValue((*outerInt).m).(*FuncValue).Get() { - t.Errorf("Wrong method table for outerInt: (m=%p)", (*outerInt).m) + typ := TypeOf((*OuterInt)(nil)) + if typ.NumMethod() != 1 || typ.Method(0).Func.Pointer() != ValueOf((*OuterInt).M).Pointer() { + t.Errorf("Wrong method table for OuterInt: (m=%p)", (*OuterInt).M) for i := 0; i < typ.NumMethod(); i++ { m := typ.Method(i) - t.Errorf("\t%d: %s %#x\n", i, m.Name, m.Func.Get()) + t.Errorf("\t%d: %s %#x\n", i, m.Name, m.Func.Pointer()) } } - i := &innerInt{3} - if v := NewValue(i).Method(0).Call(nil)[0].(*IntValue).Get(); v != 3 { - t.Errorf("i.m() = %d, want 3", v) + i := &InnerInt{3} + if v := ValueOf(i).Method(0).Call(nil)[0].Int(); v != 3 { + t.Errorf("i.M() = %d, want 3", v) } - o := &outerInt{1, innerInt{2}} - if v := NewValue(o).Method(0).Call(nil)[0].(*IntValue).Get(); v != 2 { - t.Errorf("i.m() = %d, want 2", v) + o := &OuterInt{1, InnerInt{2}} + if v := ValueOf(o).Method(0).Call(nil)[0].Int(); v != 2 { + t.Errorf("i.M() = %d, want 2", v) } - f := (*outerInt).m + f := (*OuterInt).M if v := f(o); v != 2 { t.Errorf("f(o) = %d, want 2", v) } @@ -1397,15 +1381,15 @@ func TestEmbeddedMethods(t *testing.T) { func TestPtrTo(t *testing.T) { var i int - typ := Typeof(i) + typ := TypeOf(i) for i = 0; i < 100; i++ { typ = PtrTo(typ) } for i = 0; i < 100; i++ { - typ = typ.(*PtrType).Elem() + typ = typ.Elem() } - if typ != Typeof(i) { - t.Errorf("after 100 PtrTo and Elem, have %s, want %s", typ, Typeof(i)) + if typ != TypeOf(i) { + t.Errorf("after 100 PtrTo and Elem, have %s, want %s", typ, TypeOf(i)) } } @@ -1414,38 +1398,40 @@ func TestAddr(t *testing.T) { X, Y int } - v := NewValue(&p) - v = v.(*PtrValue).Elem() + v := ValueOf(&p) + v = v.Elem() v = v.Addr() - v = v.(*PtrValue).Elem() - v = v.(*StructValue).Field(0) - v.(*IntValue).Set(2) + v = v.Elem() + v = v.Field(0) + v.SetInt(2) if p.X != 2 { t.Errorf("Addr.Elem.Set failed to set value") } - // Again but take address of the NewValue value. + // Again but take address of the ValueOf value. // Exercises generation of PtrTypes not present in the binary. - v = NewValue(&p) + q := &p + v = ValueOf(&q).Elem() v = v.Addr() - v = v.(*PtrValue).Elem() - v = v.(*PtrValue).Elem() + v = v.Elem() + v = v.Elem() v = v.Addr() - v = v.(*PtrValue).Elem() - v = v.(*StructValue).Field(0) - v.(*IntValue).Set(3) + v = v.Elem() + v = v.Field(0) + v.SetInt(3) if p.X != 3 { t.Errorf("Addr.Elem.Set failed to set value") } // Starting without pointer we should get changed value // in interface. - v = NewValue(p) + qq := p + v = ValueOf(&qq).Elem() v0 := v v = v.Addr() - v = v.(*PtrValue).Elem() - v = v.(*StructValue).Field(0) - v.(*IntValue).Set(4) + v = v.Elem() + v = v.Field(0) + v.SetInt(4) if p.X != 3 { // should be unchanged from last time t.Errorf("somehow value Set changed original p") } @@ -1456,3 +1442,71 @@ func TestAddr(t *testing.T) { t.Errorf("Addr.Elem.Set valued to set value in top value") } } + +/* gccgo does do allocations here. + +func noAlloc(t *testing.T, n int, f func(int)) { + // once to prime everything + f(-1) + runtime.MemStats.Mallocs = 0 + + for j := 0; j < n; j++ { + f(j) + } + if runtime.MemStats.Mallocs != 0 { + t.Fatalf("%d mallocs after %d iterations", runtime.MemStats.Mallocs, n) + } +} + +func TestAllocations(t *testing.T) { + noAlloc(t, 100, func(j int) { + var i interface{} + var v Value + i = 42 + j + v = ValueOf(i) + if int(v.Int()) != 42+j { + panic("wrong int") + } + }) +} + +*/ + +func TestSmallNegativeInt(t *testing.T) { + i := int16(-1) + v := ValueOf(i) + if v.Int() != -1 { + t.Errorf("int16(-1).Int() returned %v", v.Int()) + } +} + +func TestSlice(t *testing.T) { + xs := []int{1, 2, 3, 4, 5, 6, 7, 8} + v := ValueOf(xs).Slice(3, 5).Interface().([]int) + if len(v) != 2 || v[0] != 4 || v[1] != 5 { + t.Errorf("xs.Slice(3, 5) = %v", v) + } + + xa := [7]int{10, 20, 30, 40, 50, 60, 70} + v = ValueOf(&xa).Elem().Slice(2, 5).Interface().([]int) + if len(v) != 3 || v[0] != 30 || v[1] != 40 || v[2] != 50 { + t.Errorf("xa.Slice(2, 5) = %v", v) + } +} + +func TestVariadic(t *testing.T) { + var b bytes.Buffer + V := ValueOf + + b.Reset() + V(fmt.Fprintf).Call([]Value{V(&b), V("%s, %d world"), V("hello"), V(42)}) + if b.String() != "hello, 42 world" { + t.Errorf("after Fprintf Call: %q != %q", b.String(), "hello 42 world") + } + + b.Reset() + V(fmt.Fprintf).CallSlice([]Value{V(&b), V("%s, %d world"), V([]interface{}{"hello", 42})}) + if b.String() != "hello, 42 world" { + t.Errorf("after Fprintf CallSlice: %q != %q", b.String(), "hello 42 world") + } +} diff --git a/libgo/go/reflect/deepequal.go b/libgo/go/reflect/deepequal.go index c9beec50666..a483135b017 100644 --- a/libgo/go/reflect/deepequal.go +++ b/libgo/go/reflect/deepequal.go @@ -6,7 +6,6 @@ package reflect - // During deepValueEqual, must keep track of checks that are // in progress. The comparison algorithm assumes that all // checks in progress are true when it reencounters them. @@ -21,9 +20,9 @@ type visit struct { // Tests for deep equality using reflected types. The map argument tracks // comparisons that have already been seen, which allows short circuiting on // recursive types. -func deepValueEqual(v1, v2 Value, visited map[uintptr]*visit, depth int) bool { - if v1 == nil || v2 == nil { - return v1 == v2 +func deepValueEqual(v1, v2 Value, visited map[uintptr]*visit, depth int) (b bool) { + if !v1.IsValid() || !v2.IsValid() { + return v1.IsValid() == v2.IsValid() } if v1.Type() != v2.Type() { return false @@ -31,82 +30,74 @@ func deepValueEqual(v1, v2 Value, visited map[uintptr]*visit, depth int) bool { // if depth > 10 { panic("deepValueEqual") } // for debugging - addr1 := v1.UnsafeAddr() - addr2 := v2.UnsafeAddr() - if addr1 > addr2 { - // Canonicalize order to reduce number of entries in visited. - addr1, addr2 = addr2, addr1 - } - - // Short circuit if references are identical ... - if addr1 == addr2 { - return true - } + if v1.CanAddr() && v2.CanAddr() { + addr1 := v1.UnsafeAddr() + addr2 := v2.UnsafeAddr() + if addr1 > addr2 { + // Canonicalize order to reduce number of entries in visited. + addr1, addr2 = addr2, addr1 + } - // ... or already seen - h := 17*addr1 + addr2 - seen := visited[h] - typ := v1.Type() - for p := seen; p != nil; p = p.next { - if p.a1 == addr1 && p.a2 == addr2 && p.typ == typ { + // Short circuit if references are identical ... + if addr1 == addr2 { return true } - } - // Remember for later. - visited[h] = &visit{addr1, addr2, typ, seen} + // ... or already seen + h := 17*addr1 + addr2 + seen := visited[h] + typ := v1.Type() + for p := seen; p != nil; p = p.next { + if p.a1 == addr1 && p.a2 == addr2 && p.typ == typ { + return true + } + } + + // Remember for later. + visited[h] = &visit{addr1, addr2, typ, seen} + } - switch v := v1.(type) { - case *ArrayValue: - arr1 := v - arr2 := v2.(*ArrayValue) - if arr1.Len() != arr2.Len() { + switch v1.Kind() { + case Array: + if v1.Len() != v2.Len() { return false } - for i := 0; i < arr1.Len(); i++ { - if !deepValueEqual(arr1.Elem(i), arr2.Elem(i), visited, depth+1) { + for i := 0; i < v1.Len(); i++ { + if !deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) { return false } } return true - case *SliceValue: - arr1 := v - arr2 := v2.(*SliceValue) - if arr1.Len() != arr2.Len() { + case Slice: + if v1.Len() != v2.Len() { return false } - for i := 0; i < arr1.Len(); i++ { - if !deepValueEqual(arr1.Elem(i), arr2.Elem(i), visited, depth+1) { + for i := 0; i < v1.Len(); i++ { + if !deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) { return false } } return true - case *InterfaceValue: - i1 := v.Interface() - i2 := v2.Interface() - if i1 == nil || i2 == nil { - return i1 == i2 + case Interface: + if v1.IsNil() || v2.IsNil() { + return v1.IsNil() == v2.IsNil() } - return deepValueEqual(NewValue(i1), NewValue(i2), visited, depth+1) - case *PtrValue: - return deepValueEqual(v.Elem(), v2.(*PtrValue).Elem(), visited, depth+1) - case *StructValue: - struct1 := v - struct2 := v2.(*StructValue) - for i, n := 0, v.NumField(); i < n; i++ { - if !deepValueEqual(struct1.Field(i), struct2.Field(i), visited, depth+1) { + return deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1) + case Ptr: + return deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1) + case Struct: + for i, n := 0, v1.NumField(); i < n; i++ { + if !deepValueEqual(v1.Field(i), v2.Field(i), visited, depth+1) { return false } } return true - case *MapValue: - map1 := v - map2 := v2.(*MapValue) - if map1.Len() != map2.Len() { + case Map: + if v1.Len() != v2.Len() { return false } - for _, k := range map1.Keys() { - if !deepValueEqual(map1.Elem(k), map2.Elem(k), visited, depth+1) { + for _, k := range v1.MapKeys() { + if !deepValueEqual(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) { return false } } @@ -126,8 +117,8 @@ func DeepEqual(a1, a2 interface{}) bool { if a1 == nil || a2 == nil { return a1 == a2 } - v1 := NewValue(a1) - v2 := NewValue(a2) + v1 := ValueOf(a1) + v2 := ValueOf(a2) if v1.Type() != v2.Type() { return false } diff --git a/libgo/go/reflect/set_test.go b/libgo/go/reflect/set_test.go new file mode 100644 index 00000000000..8135a4cd148 --- /dev/null +++ b/libgo/go/reflect/set_test.go @@ -0,0 +1,211 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package reflect_test + +import ( + "bytes" + "go/ast" + "io" + . "reflect" + "testing" + "unsafe" +) + +type MyBuffer bytes.Buffer + +func TestImplicitMapConversion(t *testing.T) { + // Test implicit conversions in MapIndex and SetMapIndex. + { + // direct + m := make(map[int]int) + mv := ValueOf(m) + mv.SetMapIndex(ValueOf(1), ValueOf(2)) + x, ok := m[1] + if x != 2 { + t.Errorf("#1 after SetMapIndex(1,2): %d, %t (map=%v)", x, ok, m) + } + if n := mv.MapIndex(ValueOf(1)).Interface().(int); n != 2 { + t.Errorf("#1 MapIndex(1) = %d", n) + } + } + { + // convert interface key + m := make(map[interface{}]int) + mv := ValueOf(m) + mv.SetMapIndex(ValueOf(1), ValueOf(2)) + x, ok := m[1] + if x != 2 { + t.Errorf("#2 after SetMapIndex(1,2): %d, %t (map=%v)", x, ok, m) + } + if n := mv.MapIndex(ValueOf(1)).Interface().(int); n != 2 { + t.Errorf("#2 MapIndex(1) = %d", n) + } + } + { + // convert interface value + m := make(map[int]interface{}) + mv := ValueOf(m) + mv.SetMapIndex(ValueOf(1), ValueOf(2)) + x, ok := m[1] + if x != 2 { + t.Errorf("#3 after SetMapIndex(1,2): %d, %t (map=%v)", x, ok, m) + } + if n := mv.MapIndex(ValueOf(1)).Interface().(int); n != 2 { + t.Errorf("#3 MapIndex(1) = %d", n) + } + } + { + // convert both interface key and interface value + m := make(map[interface{}]interface{}) + mv := ValueOf(m) + mv.SetMapIndex(ValueOf(1), ValueOf(2)) + x, ok := m[1] + if x != 2 { + t.Errorf("#4 after SetMapIndex(1,2): %d, %t (map=%v)", x, ok, m) + } + if n := mv.MapIndex(ValueOf(1)).Interface().(int); n != 2 { + t.Errorf("#4 MapIndex(1) = %d", n) + } + } + { + // convert both, with non-empty interfaces + m := make(map[io.Reader]io.Writer) + mv := ValueOf(m) + b1 := new(bytes.Buffer) + b2 := new(bytes.Buffer) + mv.SetMapIndex(ValueOf(b1), ValueOf(b2)) + x, ok := m[b1] + if x != b2 { + t.Errorf("#5 after SetMapIndex(b1, b2): %p (!= %p), %t (map=%v)", x, b2, ok, m) + } + if p := mv.MapIndex(ValueOf(b1)).Elem().Pointer(); p != uintptr(unsafe.Pointer(b2)) { + t.Errorf("#5 MapIndex(b1) = %p want %p", p, b2) + } + } + { + // convert channel direction + m := make(map[<-chan int]chan int) + mv := ValueOf(m) + c1 := make(chan int) + c2 := make(chan int) + mv.SetMapIndex(ValueOf(c1), ValueOf(c2)) + x, ok := m[c1] + if x != c2 { + t.Errorf("#6 after SetMapIndex(c1, c2): %p (!= %p), %t (map=%v)", x, c2, ok, m) + } + if p := mv.MapIndex(ValueOf(c1)).Pointer(); p != ValueOf(c2).Pointer() { + t.Errorf("#6 MapIndex(c1) = %p want %p", p, c2) + } + } + { + // convert identical underlying types + // TODO(rsc): Should be able to define MyBuffer here. + // 6l prints very strange messages about .this.Bytes etc + // when we do that though, so MyBuffer is defined + // at top level. + m := make(map[*MyBuffer]*bytes.Buffer) + mv := ValueOf(m) + b1 := new(MyBuffer) + b2 := new(bytes.Buffer) + mv.SetMapIndex(ValueOf(b1), ValueOf(b2)) + x, ok := m[b1] + if x != b2 { + t.Errorf("#7 after SetMapIndex(b1, b2): %p (!= %p), %t (map=%v)", x, b2, ok, m) + } + if p := mv.MapIndex(ValueOf(b1)).Pointer(); p != uintptr(unsafe.Pointer(b2)) { + t.Errorf("#7 MapIndex(b1) = %p want %p", p, b2) + } + } + +} + +func TestImplicitSetConversion(t *testing.T) { + // Assume TestImplicitMapConversion covered the basics. + // Just make sure conversions are being applied at all. + var r io.Reader + b := new(bytes.Buffer) + rv := ValueOf(&r).Elem() + rv.Set(ValueOf(b)) + if r != b { + t.Errorf("after Set: r=%T(%v)", r, r) + } +} + +func TestImplicitSendConversion(t *testing.T) { + c := make(chan io.Reader, 10) + b := new(bytes.Buffer) + ValueOf(c).Send(ValueOf(b)) + if bb := <-c; bb != b { + t.Errorf("Received %p != %p", bb, b) + } +} + +func TestImplicitCallConversion(t *testing.T) { + // Arguments must be assignable to parameter types. + fv := ValueOf(io.WriteString) + b := new(bytes.Buffer) + fv.Call([]Value{ValueOf(b), ValueOf("hello world")}) + if b.String() != "hello world" { + t.Errorf("After call: string=%q want %q", b.String(), "hello world") + } +} + +func TestImplicitAppendConversion(t *testing.T) { + // Arguments must be assignable to the slice's element type. + s := []io.Reader{} + sv := ValueOf(&s).Elem() + b := new(bytes.Buffer) + sv.Set(Append(sv, ValueOf(b))) + if len(s) != 1 || s[0] != b { + t.Errorf("after append: s=%v want [%p]", s, b) + } +} + +var implementsTests = []struct { + x interface{} + t interface{} + b bool +}{ + {new(*bytes.Buffer), new(io.Reader), true}, + {new(bytes.Buffer), new(io.Reader), false}, + {new(*bytes.Buffer), new(io.ReaderAt), false}, + {new(*ast.Ident), new(ast.Expr), true}, +} + +func TestImplements(t *testing.T) { + for _, tt := range implementsTests { + xv := TypeOf(tt.x).Elem() + xt := TypeOf(tt.t).Elem() + if b := xv.Implements(xt); b != tt.b { + t.Errorf("(%s).Implements(%s) = %v, want %v", xv.String(), xt.String(), b, tt.b) + } + } +} + +var assignableTests = []struct { + x interface{} + t interface{} + b bool +}{ + {new(chan int), new(<-chan int), true}, + {new(<-chan int), new(chan int), false}, + {new(*int), new(IntPtr), true}, + {new(IntPtr), new(*int), true}, + {new(IntPtr), new(IntPtr1), false}, + // test runs implementsTests too +} + +type IntPtr *int +type IntPtr1 *int + +func TestAssignableTo(t *testing.T) { + for _, tt := range append(assignableTests, implementsTests...) { + xv := TypeOf(tt.x).Elem() + xt := TypeOf(tt.t).Elem() + if b := xv.AssignableTo(xt); b != tt.b { + t.Errorf("(%s).AssignableTo(%s) = %v, want %v", xv.String(), xt.String(), b, tt.b) + } + } +} diff --git a/libgo/go/reflect/tostring_test.go b/libgo/go/reflect/tostring_test.go index a1487fdd2fe..5f5c52b778a 100644 --- a/libgo/go/reflect/tostring_test.go +++ b/libgo/go/reflect/tostring_test.go @@ -17,29 +17,29 @@ import ( // For debugging only. func valueToString(val Value) string { var str string - if val == nil { - return "" + if !val.IsValid() { + return "" } typ := val.Type() - switch val := val.(type) { - case *IntValue: - return strconv.Itoa64(val.Get()) - case *UintValue: - return strconv.Uitoa64(val.Get()) - case *FloatValue: - return strconv.Ftoa64(float64(val.Get()), 'g', -1) - case *ComplexValue: - c := val.Get() - return strconv.Ftoa64(float64(real(c)), 'g', -1) + "+" + strconv.Ftoa64(float64(imag(c)), 'g', -1) + "i" - case *StringValue: - return val.Get() - case *BoolValue: - if val.Get() { + switch val.Kind() { + case Int, Int8, Int16, Int32, Int64: + return strconv.Itoa64(val.Int()) + case Uint, Uint8, Uint16, Uint32, Uint64, Uintptr: + return strconv.Uitoa64(val.Uint()) + case Float32, Float64: + return strconv.Ftoa64(val.Float(), 'g', -1) + case Complex64, Complex128: + c := val.Complex() + return strconv.Ftoa64(real(c), 'g', -1) + "+" + strconv.Ftoa64(imag(c), 'g', -1) + "i" + case String: + return val.String() + case Bool: + if val.Bool() { return "true" } else { return "false" } - case *PtrValue: + case Ptr: v := val str = typ.String() + "(" if v.IsNil() { @@ -49,7 +49,7 @@ func valueToString(val Value) string { } str += ")" return str - case ArrayOrSliceValue: + case Array, Slice: v := val str += typ.String() str += "{" @@ -57,22 +57,22 @@ func valueToString(val Value) string { if i > 0 { str += ", " } - str += valueToString(v.Elem(i)) + str += valueToString(v.Index(i)) } str += "}" return str - case *MapValue: - t := typ.(*MapType) + case Map: + t := typ str = t.String() str += "{" str += "" str += "}" return str - case *ChanValue: + case Chan: str = typ.String() return str - case *StructValue: - t := typ.(*StructType) + case Struct: + t := typ v := val str += t.String() str += "{" @@ -84,11 +84,11 @@ func valueToString(val Value) string { } str += "}" return str - case *InterfaceValue: + case Interface: return typ.String() + "(" + valueToString(val.Elem()) + ")" - case *FuncValue: + case Func: v := val - return typ.String() + "(" + strconv.Itoa64(int64(v.Get())) + ")" + return typ.String() + "(" + strconv.Uitoa64(uint64(v.Pointer())) + ")" default: panic("valueToString: can't print type " + typ.String()) } diff --git a/libgo/go/reflect/type.go b/libgo/go/reflect/type.go index fb798802815..30bf54a1f18 100644 --- a/libgo/go/reflect/type.go +++ b/libgo/go/reflect/type.go @@ -2,17 +2,14 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The reflect package implements run-time reflection, allowing a program to -// manipulate objects with arbitrary types. The typical use is to take a -// value with static type interface{} and extract its dynamic type -// information by calling Typeof, which returns an object with interface -// type Type. That contains a pointer to a struct of type *StructType, -// *IntType, etc. representing the details of the underlying type. A type -// switch or type assertion can reveal which. +// Package reflect implements run-time reflection, allowing a program to +// manipulate objects with arbitrary types. The typical use is to take a value +// with static type interface{} and extract its dynamic type information by +// calling TypeOf, which returns a Type. // -// A call to NewValue creates a Value representing the run-time data; it -// contains a *StructValue, *IntValue, etc. MakeZero takes a Type and -// returns a Value representing a zero value for that type. +// A call to ValueOf returns a Value representing the run-time data. +// Zero takes a Type and returns a Value representing a zero value +// for that type. package reflect import ( @@ -22,6 +19,194 @@ import ( "unsafe" ) +// Type is the representation of a Go type. +// +// Not all methods apply to all kinds of types. Restrictions, +// if any, are noted in the documentation for each method. +// Use the Kind method to find out the kind of type before +// calling kind-specific methods. Calling a method +// inappropriate to the kind of type causes a run-time panic. +type Type interface { + // Methods applicable to all types. + + // Align returns the alignment in bytes of a value of + // this type when allocated in memory. + Align() int + + // FieldAlign returns the alignment in bytes of a value of + // this type when used as a field in a struct. + FieldAlign() int + + // Method returns the i'th method in the type's method set. + // It panics if i is not in the range [0, NumMethod()). + // + // For a non-interface type T or *T, the returned Method's Type and Func + // fields describe a function whose first argument is the receiver. + // + // For an interface type, the returned Method's Type field gives the + // method signature, without a receiver, and the Func field is nil. + Method(int) Method + + // NumMethod returns the number of methods in the type's method set. + NumMethod() int + + // Name returns the type's name within its package. + // It returns an empty string for unnamed types. + Name() string + + // PkgPath returns the type's package path. + // The package path is a full package import path like "container/vector". + // PkgPath returns an empty string for unnamed types. + PkgPath() string + + // Size returns the number of bytes needed to store + // a value of the given type; it is analogous to unsafe.Sizeof. + Size() uintptr + + // String returns a string representation of the type. + // The string representation may use shortened package names + // (e.g., vector instead of "container/vector") and is not + // guaranteed to be unique among types. To test for equality, + // compare the Types directly. + String() string + + // Kind returns the specific kind of this type. + Kind() Kind + + // Implements returns true if the type implements the interface type u. + Implements(u Type) bool + + // AssignableTo returns true if a value of the type is assignable to type u. + AssignableTo(u Type) bool + + // Methods applicable only to some types, depending on Kind. + // The methods allowed for each kind are: + // + // Int*, Uint*, Float*, Complex*: Bits + // Array: Elem, Len + // Chan: ChanDir, Elem + // Func: In, NumIn, Out, NumOut, IsVariadic. + // Map: Key, Elem + // Ptr: Elem + // Slice: Elem + // Struct: Field, FieldByIndex, FieldByName, FieldByNameFunc, NumField + + // Bits returns the size of the type in bits. + // It panics if the type's Kind is not one of the + // sized or unsized Int, Uint, Float, or Complex kinds. + Bits() int + + // ChanDir returns a channel type's direction. + // It panics if the type's Kind is not Chan. + ChanDir() ChanDir + + // IsVariadic returns true if a function type's final input parameter + // is a "..." parameter. If so, t.In(t.NumIn() - 1) returns the parameter's + // implicit actual type []T. + // + // For concreteness, if t represents func(x int, y ... float), then + // + // t.NumIn() == 2 + // t.In(0) is the reflect.Type for "int" + // t.In(1) is the reflect.Type for "[]float" + // t.IsVariadic() == true + // + // IsVariadic panics if the type's Kind is not Func. + IsVariadic() bool + + // Elem returns a type's element type. + // It panics if the type's Kind is not Array, Chan, Map, Ptr, or Slice. + Elem() Type + + // Field returns a struct type's i'th field. + // It panics if the type's Kind is not Struct. + // It panics if i is not in the range [0, NumField()). + Field(i int) StructField + + // FieldByIndex returns the nested field corresponding + // to the index sequence. It is equivalent to calling Field + // successively for each index i. + // It panics if the type's Kind is not Struct. + FieldByIndex(index []int) StructField + + // FieldByName returns the struct field with the given name + // and a boolean indicating if the field was found. + FieldByName(name string) (StructField, bool) + + // FieldByNameFunc returns the first struct field with a name + // that satisfies the match function and a boolean indicating if + // the field was found. + FieldByNameFunc(match func(string) bool) (StructField, bool) + + // In returns the type of a function type's i'th input parameter. + // It panics if the type's Kind is not Func. + // It panics if i is not in the range [0, NumIn()). + In(i int) Type + + // Key returns a map type's key type. + // It panics if the type's Kind is not Map. + Key() Type + + // Len returns an array type's length. + // It panics if the type's Kind is not Array. + Len() int + + // NumField returns a struct type's field count. + // It panics if the type's Kind is not Struct. + NumField() int + + // NumIn returns a function type's input parameter count. + // It panics if the type's Kind is not Func. + NumIn() int + + // NumOut returns a function type's output parameter count. + // It panics if the type's Kind is not Func. + NumOut() int + + // Out returns the type of a function type's i'th output parameter. + // It panics if the type's Kind is not Func. + // It panics if i is not in the range [0, NumOut()). + Out(i int) Type + + runtimeType() *runtime.Type + common() *commonType + uncommon() *uncommonType +} + +// A Kind represents the specific kind of type that a Type represents. +// The zero Kind is not a valid kind. +type Kind uint8 + +const ( + Invalid Kind = iota + Bool + Int + Int8 + Int16 + Int32 + Int64 + Uint + Uint8 + Uint16 + Uint32 + Uint64 + Uintptr + Float32 + Float64 + Complex64 + Complex128 + Array + Chan + Func + Interface + Map + Ptr + Slice + String + Struct + UnsafePointer +) + /* * Copy of data structures from ../runtime/type.go. * For comments, see the ones in that file. @@ -67,48 +252,6 @@ type uncommonType struct { methods []method } -// BoolType represents a boolean type. -type BoolType struct { - commonType "bool" -} - -// FloatType represents a float type. -type FloatType struct { - commonType "float" -} - -// ComplexType represents a complex type. -type ComplexType struct { - commonType "complex" -} - -// IntType represents a signed integer type. -type IntType struct { - commonType "int" -} - -// UintType represents a uint type. -type UintType struct { - commonType "uint" -} - -// StringType represents a string type. -type StringType struct { - commonType "string" -} - -// UnsafePointerType represents an unsafe.Pointer type. -type UnsafePointerType struct { - commonType "unsafe.Pointer" -} - -// ArrayType represents a fixed array type. -type ArrayType struct { - commonType "array" - elem *runtime.Type - len uintptr -} - // ChanDir represents a channel type's direction. type ChanDir int @@ -118,57 +261,62 @@ const ( BothDir = RecvDir | SendDir ) -// ChanType represents a channel type. -type ChanType struct { + +// arrayType represents a fixed array type. +type arrayType struct { + commonType "array" + elem *runtime.Type + slice *runtime.Type + len uintptr +} + +// chanType represents a channel type. +type chanType struct { commonType "chan" elem *runtime.Type dir uintptr } -// FuncType represents a function type. -type FuncType struct { +// funcType represents a function type. +type funcType struct { commonType "func" dotdotdot bool in []*runtime.Type out []*runtime.Type } -// Method on interface type +// imethod represents a method on an interface type type imethod struct { name *string pkgPath *string typ *runtime.Type } -// InterfaceType represents an interface type. -type InterfaceType struct { +// interfaceType represents an interface type. +type interfaceType struct { commonType "interface" methods []imethod } -// MapType represents a map type. -type MapType struct { +// mapType represents a map type. +type mapType struct { commonType "map" key *runtime.Type elem *runtime.Type } -// PtrType represents a pointer type. -type PtrType struct { +// ptrType represents a pointer type. +type ptrType struct { commonType "ptr" elem *runtime.Type } -// SliceType represents a slice type. -type SliceType struct { +// sliceType represents a slice type. +type sliceType struct { commonType "slice" elem *runtime.Type } -// arrayOrSliceType is an unexported method that guarantees only -// arrays and slices implement ArrayOrSliceType. -func (*SliceType) arrayOrSliceType() {} - // Struct field type structField struct { name *string @@ -178,8 +326,8 @@ type structField struct { offset uintptr } -// StructType represents a struct type. -type StructType struct { +// structType represents a struct type. +type structType struct { commonType "struct" fields []structField } @@ -194,106 +342,10 @@ type StructType struct { type Method struct { PkgPath string // empty for uppercase Name Name string - Type *FuncType - Func *FuncValue + Type Type + Func Value } -// Type is the runtime representation of a Go type. -// Every type implements the methods listed here. -// Some types implement additional interfaces; -// use a type switch to find out what kind of type a Type is. -// Each type in a program has a unique Type, so == on Types -// corresponds to Go's type equality. -type Type interface { - // PkgPath returns the type's package path. - // The package path is a full package import path like "container/vector". - // PkgPath returns an empty string for unnamed types. - PkgPath() string - - // Name returns the type's name within its package. - // Name returns an empty string for unnamed types. - Name() string - - // String returns a string representation of the type. - // The string representation may use shortened package names - // (e.g., vector instead of "container/vector") and is not - // guaranteed to be unique among types. To test for equality, - // compare the Types directly. - String() string - - // Size returns the number of bytes needed to store - // a value of the given type; it is analogous to unsafe.Sizeof. - Size() uintptr - - // Bits returns the size of the type in bits. - // It is intended for use with numeric types and may overflow - // when used for composite types. - Bits() int - - // Align returns the alignment of a value of this type - // when allocated in memory. - Align() int - - // FieldAlign returns the alignment of a value of this type - // when used as a field in a struct. - FieldAlign() int - - // Kind returns the specific kind of this type. - Kind() Kind - - // Method returns the i'th method in the type's method set. - // - // For a non-interface type T or *T, the returned Method's Type and Func - // fields describe a function whose first argument is the receiver. - // - // For an interface type, the returned Method's Type field gives the - // method signature, without a receiver, and the Func field is nil. - Method(int) Method - - // NumMethods returns the number of methods in the type's method set. - NumMethod() int - - common() *commonType - uncommon() *uncommonType -} - -// A Kind represents the specific kind of type that a Type represents. -// For numeric types, the Kind gives more information than the Type's -// dynamic type. For example, the Type of a float32 is FloatType, but -// the Kind is Float32. -// -// The zero Kind is not a valid kind. -type Kind uint8 - -const ( - Bool Kind = 1 + iota - Int - Int8 - Int16 - Int32 - Int64 - Uint - Uint8 - Uint16 - Uint32 - Uint64 - Uintptr - Float32 - Float64 - Complex64 - Complex128 - Array - Chan - Func - Interface - Map - Ptr - Slice - String - Struct - UnsafePointer -) - // High bit says whether type has // embedded pointers,to help garbage collector. const kindMask = 0x7f @@ -306,6 +358,7 @@ func (k Kind) String() string { } var kindNames = []string{ + Invalid: "invalid", Bool: "bool", Int: "int", Int8: "int8", @@ -352,11 +405,27 @@ func (t *uncommonType) Name() string { return *t.name } +func (t *commonType) toType() Type { + if t == nil { + return nil + } + return canonicalize(t) +} + func (t *commonType) String() string { return *t.string } func (t *commonType) Size() uintptr { return t.size } -func (t *commonType) Bits() int { return int(t.size * 8) } +func (t *commonType) Bits() int { + if t == nil { + panic("reflect: Bits of nil Type") + } + k := t.Kind() + if k < Int || k > Complex128 { + panic("reflect: Bits of non-arithmetic Type " + t.String()) + } + return int(t.size) * 8 +} func (t *commonType) Align() int { return int(t.align) } @@ -374,12 +443,15 @@ func (t *uncommonType) Method(i int) (m Method) { if p.name != nil { m.Name = *p.name } + flag := uint32(0) if p.pkgPath != nil { m.PkgPath = *p.pkgPath + flag |= flagRO } - m.Type = runtimeToType(p.typ).(*FuncType) - fn := p.tfn - m.Func = &FuncValue{value: value{m.Type, addr(&fn), canSet}} + m.Type = toType(p.typ) + x := new(unsafe.Pointer) + *x = p.tfn + m.Func = valueFromIword(flag, m.Type, iword(uintptr(unsafe.Pointer(x)))) return } @@ -393,79 +465,169 @@ func (t *uncommonType) NumMethod() int { // TODO(rsc): 6g supplies these, but they are not // as efficient as they could be: they have commonType // as the receiver instead of *commonType. -func (t *commonType) NumMethod() int { return t.uncommonType.NumMethod() } +func (t *commonType) NumMethod() int { + if t.Kind() == Interface { + tt := (*interfaceType)(unsafe.Pointer(t)) + return tt.NumMethod() + } + return t.uncommonType.NumMethod() +} -func (t *commonType) Method(i int) (m Method) { return t.uncommonType.Method(i) } +func (t *commonType) Method(i int) (m Method) { + if t.Kind() == Interface { + tt := (*interfaceType)(unsafe.Pointer(t)) + return tt.Method(i) + } + return t.uncommonType.Method(i) +} -func (t *commonType) PkgPath() string { return t.uncommonType.PkgPath() } +func (t *commonType) PkgPath() string { + return t.uncommonType.PkgPath() +} -func (t *commonType) Name() string { return t.uncommonType.Name() } +func (t *commonType) Name() string { + return t.uncommonType.Name() +} -// Len returns the number of elements in the array. -func (t *ArrayType) Len() int { return int(t.len) } +func (t *commonType) ChanDir() ChanDir { + if t.Kind() != Chan { + panic("reflect: ChanDir of non-chan type") + } + tt := (*chanType)(unsafe.Pointer(t)) + return ChanDir(tt.dir) +} -// Elem returns the type of the array's elements. -func (t *ArrayType) Elem() Type { return runtimeToType(t.elem) } +func (t *commonType) IsVariadic() bool { + if t.Kind() != Func { + panic("reflect: IsVariadic of non-func type") + } + tt := (*funcType)(unsafe.Pointer(t)) + return tt.dotdotdot +} -// arrayOrSliceType is an unexported method that guarantees only -// arrays and slices implement ArrayOrSliceType. -func (*ArrayType) arrayOrSliceType() {} +func (t *commonType) Elem() Type { + switch t.Kind() { + case Array: + tt := (*arrayType)(unsafe.Pointer(t)) + return toType(tt.elem) + case Chan: + tt := (*chanType)(unsafe.Pointer(t)) + return toType(tt.elem) + case Map: + tt := (*mapType)(unsafe.Pointer(t)) + return toType(tt.elem) + case Ptr: + tt := (*ptrType)(unsafe.Pointer(t)) + return toType(tt.elem) + case Slice: + tt := (*sliceType)(unsafe.Pointer(t)) + return toType(tt.elem) + } + panic("reflect; Elem of invalid type") +} -// Dir returns the channel direction. -func (t *ChanType) Dir() ChanDir { return ChanDir(t.dir) } +func (t *commonType) Field(i int) StructField { + if t.Kind() != Struct { + panic("reflect: Field of non-struct type") + } + tt := (*structType)(unsafe.Pointer(t)) + return tt.Field(i) +} -// Elem returns the channel's element type. -func (t *ChanType) Elem() Type { return runtimeToType(t.elem) } +func (t *commonType) FieldByIndex(index []int) StructField { + if t.Kind() != Struct { + panic("reflect: FieldByIndex of non-struct type") + } + tt := (*structType)(unsafe.Pointer(t)) + return tt.FieldByIndex(index) +} -func (d ChanDir) String() string { - switch d { - case SendDir: - return "chan<-" - case RecvDir: - return "<-chan" - case BothDir: - return "chan" +func (t *commonType) FieldByName(name string) (StructField, bool) { + if t.Kind() != Struct { + panic("reflect: FieldByName of non-struct type") } - return "ChanDir" + strconv.Itoa(int(d)) + tt := (*structType)(unsafe.Pointer(t)) + return tt.FieldByName(name) } -// In returns the type of the i'th function input parameter. -func (t *FuncType) In(i int) Type { - if i < 0 || i >= len(t.in) { - return nil +func (t *commonType) FieldByNameFunc(match func(string) bool) (StructField, bool) { + if t.Kind() != Struct { + panic("reflect: FieldByNameFunc of non-struct type") } - return runtimeToType(t.in[i]) + tt := (*structType)(unsafe.Pointer(t)) + return tt.FieldByNameFunc(match) } -// DotDotDot returns true if the final function input parameter -// is a "..." parameter. If so, t.In(t.NumIn() - 1) returns the -// parameter's underlying static type []T. -// -// For concreteness, if t is func(x int, y ... float), then -// -// t.NumIn() == 2 -// t.In(0) is the reflect.Type for "int" -// t.In(1) is the reflect.Type for "[]float" -// t.DotDotDot() == true -// -func (t *FuncType) DotDotDot() bool { return t.dotdotdot } +func (t *commonType) In(i int) Type { + if t.Kind() != Func { + panic("reflect: In of non-func type") + } + tt := (*funcType)(unsafe.Pointer(t)) + return toType(tt.in[i]) +} -// NumIn returns the number of input parameters. -func (t *FuncType) NumIn() int { return len(t.in) } +func (t *commonType) Key() Type { + if t.Kind() != Map { + panic("reflect: Key of non-map type") + } + tt := (*mapType)(unsafe.Pointer(t)) + return toType(tt.key) +} -// Out returns the type of the i'th function output parameter. -func (t *FuncType) Out(i int) Type { - if i < 0 || i >= len(t.out) { - return nil +func (t *commonType) Len() int { + if t.Kind() != Array { + panic("reflect: Len of non-array type") } - return runtimeToType(t.out[i]) + tt := (*arrayType)(unsafe.Pointer(t)) + return int(tt.len) } -// NumOut returns the number of function output parameters. -func (t *FuncType) NumOut() int { return len(t.out) } +func (t *commonType) NumField() int { + if t.Kind() != Struct { + panic("reflect: NumField of non-struct type") + } + tt := (*structType)(unsafe.Pointer(t)) + return len(tt.fields) +} + +func (t *commonType) NumIn() int { + if t.Kind() != Func { + panic("reflect; NumIn of non-func type") + } + tt := (*funcType)(unsafe.Pointer(t)) + return len(tt.in) +} + +func (t *commonType) NumOut() int { + if t.Kind() != Func { + panic("reflect; NumOut of non-func type") + } + tt := (*funcType)(unsafe.Pointer(t)) + return len(tt.out) +} + +func (t *commonType) Out(i int) Type { + if t.Kind() != Func { + panic("reflect: Out of non-func type") + } + tt := (*funcType)(unsafe.Pointer(t)) + return toType(tt.out[i]) +} + +func (d ChanDir) String() string { + switch d { + case SendDir: + return "chan<-" + case RecvDir: + return "<-chan" + case BothDir: + return "chan" + } + return "ChanDir" + strconv.Itoa(int(d)) +} // Method returns the i'th method in the type's method set. -func (t *InterfaceType) Method(i int) (m Method) { +func (t *interfaceType) Method(i int) (m Method) { if i < 0 || i >= len(t.methods) { return } @@ -474,24 +636,12 @@ func (t *InterfaceType) Method(i int) (m Method) { if p.pkgPath != nil { m.PkgPath = *p.pkgPath } - m.Type = runtimeToType(p.typ).(*FuncType) + m.Type = toType(p.typ) return } // NumMethod returns the number of interface methods in the type's method set. -func (t *InterfaceType) NumMethod() int { return len(t.methods) } - -// Key returns the map key type. -func (t *MapType) Key() Type { return runtimeToType(t.key) } - -// Elem returns the map element type. -func (t *MapType) Elem() Type { return runtimeToType(t.elem) } - -// Elem returns the pointer element type. -func (t *PtrType) Elem() Type { return runtimeToType(t.elem) } - -// Elem returns the type of the slice's elements. -func (t *SliceType) Elem() Type { return runtimeToType(t.elem) } +func (t *interfaceType) NumMethod() int { return len(t.methods) } type StructField struct { PkgPath string // empty for uppercase Name @@ -504,18 +654,18 @@ type StructField struct { } // Field returns the i'th struct field. -func (t *StructType) Field(i int) (f StructField) { +func (t *structType) Field(i int) (f StructField) { if i < 0 || i >= len(t.fields) { return } p := t.fields[i] - f.Type = runtimeToType(p.typ) + f.Type = toType(p.typ) if p.name != nil { f.Name = *p.name } else { t := f.Type - if pt, ok := t.(*PtrType); ok { - t = pt.Elem() + if t.Kind() == Ptr { + t = t.Elem() } f.Name = t.Name() f.Anonymous = true @@ -535,29 +685,24 @@ func (t *StructType) Field(i int) (f StructField) { // is wrong for FieldByIndex? // FieldByIndex returns the nested field corresponding to index. -func (t *StructType) FieldByIndex(index []int) (f StructField) { +func (t *structType) FieldByIndex(index []int) (f StructField) { + f.Type = Type(t.toType()) for i, x := range index { if i > 0 { ft := f.Type - if pt, ok := ft.(*PtrType); ok { - ft = pt.Elem() - } - if st, ok := ft.(*StructType); ok { - t = st - } else { - var f0 StructField - f = f0 - return + if ft.Kind() == Ptr && ft.Elem().Kind() == Struct { + ft = ft.Elem() } + f.Type = ft } - f = t.Field(x) + f = f.Type.Field(x) } return } const inf = 1 << 30 // infinity - no struct has that many nesting levels -func (t *StructType) fieldByNameFunc(match func(string) bool, mark map[*StructType]bool, depth int) (ff StructField, fd int) { +func (t *structType) fieldByNameFunc(match func(string) bool, mark map[*structType]bool, depth int) (ff StructField, fd int) { fd = inf // field depth if mark[t] { @@ -578,8 +723,8 @@ L: d = depth case f.Anonymous: ft := f.Type - if pt, ok := ft.(*PtrType); ok { - ft = pt.Elem() + if ft.Kind() == Ptr { + ft = ft.Elem() } switch { case match(ft.Name()): @@ -587,7 +732,8 @@ L: d = depth case fd > depth: // No top-level field yet; look inside nested structs. - if st, ok := ft.(*StructType); ok { + if ft.Kind() == Struct { + st := (*structType)(unsafe.Pointer(ft.(*commonType))) f, d = st.fieldByNameFunc(match, mark, depth+1) } } @@ -626,27 +772,36 @@ L: // FieldByName returns the struct field with the given name // and a boolean to indicate if the field was found. -func (t *StructType) FieldByName(name string) (f StructField, present bool) { +func (t *structType) FieldByName(name string) (f StructField, present bool) { return t.FieldByNameFunc(func(s string) bool { return s == name }) } // FieldByNameFunc returns the struct field with a name that satisfies the // match function and a boolean to indicate if the field was found. -func (t *StructType) FieldByNameFunc(match func(string) bool) (f StructField, present bool) { - if ff, fd := t.fieldByNameFunc(match, make(map[*StructType]bool), 0); fd < inf { +func (t *structType) FieldByNameFunc(match func(string) bool) (f StructField, present bool) { + if ff, fd := t.fieldByNameFunc(match, make(map[*structType]bool), 0); fd < inf { ff.Index = ff.Index[0 : fd+1] f, present = ff, true } return } -// NumField returns the number of struct fields. -func (t *StructType) NumField() int { return len(t.fields) } +// Convert runtime type to reflect type. +func toCommonType(p *runtime.Type) *commonType { + if p == nil { + return nil + } + x := unsafe.Pointer(p) + if uintptr(x)&reflectFlags != 0 { + panic("invalid interface value") + } + return (*commonType)(x) +} // Canonicalize a Type. var canonicalType = make(map[string]Type) -var canonicalTypeLock sync.Mutex +var canonicalTypeLock sync.RWMutex func canonicalize(t Type) Type { if t == nil { @@ -659,6 +814,12 @@ func canonicalize(t Type) Type { } else { s = u.PkgPath() + "." + u.Name() } + canonicalTypeLock.RLock() + if r, ok := canonicalType[s]; ok { + canonicalTypeLock.RUnlock() + return r + } + canonicalTypeLock.RUnlock() canonicalTypeLock.Lock() if r, ok := canonicalType[s]; ok { canonicalTypeLock.Unlock() @@ -669,121 +830,36 @@ func canonicalize(t Type) Type { return t } -// Convert runtime type to reflect type. -// Same memory layouts, different method sets. -func toType(i interface{}) Type { - switch v := i.(type) { - case nil: +func toType(p *runtime.Type) Type { + if p == nil { return nil - case *runtime.BoolType: - return (*BoolType)(unsafe.Pointer(v)) - case *runtime.FloatType: - return (*FloatType)(unsafe.Pointer(v)) - case *runtime.ComplexType: - return (*ComplexType)(unsafe.Pointer(v)) - case *runtime.IntType: - return (*IntType)(unsafe.Pointer(v)) - case *runtime.StringType: - return (*StringType)(unsafe.Pointer(v)) - case *runtime.UintType: - return (*UintType)(unsafe.Pointer(v)) - case *runtime.UnsafePointerType: - return (*UnsafePointerType)(unsafe.Pointer(v)) - case *runtime.ArrayType: - return (*ArrayType)(unsafe.Pointer(v)) - case *runtime.ChanType: - return (*ChanType)(unsafe.Pointer(v)) - case *runtime.FuncType: - return (*FuncType)(unsafe.Pointer(v)) - case *runtime.InterfaceType: - return (*InterfaceType)(unsafe.Pointer(v)) - case *runtime.MapType: - return (*MapType)(unsafe.Pointer(v)) - case *runtime.PtrType: - return (*PtrType)(unsafe.Pointer(v)) - case *runtime.SliceType: - return (*SliceType)(unsafe.Pointer(v)) - case *runtime.StructType: - return (*StructType)(unsafe.Pointer(v)) - } - println(i) - panic("toType") -} - -// Convert pointer to runtime Type structure to our Type structure. -func runtimeToType(v *runtime.Type) Type { - var r Type - switch Kind(v.Kind) { - case Bool: - r = (*BoolType)(unsafe.Pointer(v)) - case Int, Int8, Int16, Int32, Int64: - r = (*IntType)(unsafe.Pointer(v)) - case Uint, Uint8, Uint16, Uint32, Uint64, Uintptr: - r = (*UintType)(unsafe.Pointer(v)) - case Float32, Float64: - r = (*FloatType)(unsafe.Pointer(v)) - case Complex64, Complex128: - r = (*ComplexType)(unsafe.Pointer(v)) - case Array: - r = (*ArrayType)(unsafe.Pointer(v)) - case Chan: - r = (*ChanType)(unsafe.Pointer(v)) - case Func: - r = (*FuncType)(unsafe.Pointer(v)) - case Interface: - r = (*InterfaceType)(unsafe.Pointer(v)) - case Map: - r = (*MapType)(unsafe.Pointer(v)) - case Ptr: - r = (*PtrType)(unsafe.Pointer(v)) - case Slice: - r = (*SliceType)(unsafe.Pointer(v)) - case String: - r = (*StringType)(unsafe.Pointer(v)) - case Struct: - r = (*StructType)(unsafe.Pointer(v)) - case UnsafePointer: - r = (*UnsafePointerType)(unsafe.Pointer(v)) - default: - panic("runtimeToType") } - return canonicalize(r) - panic("runtimeToType") + return toCommonType(p).toType() } -// ArrayOrSliceType is the common interface implemented -// by both ArrayType and SliceType. -type ArrayOrSliceType interface { - Type - Elem() Type - arrayOrSliceType() // Guarantees only Array and Slice implement this interface. +// TypeOf returns the reflection Type of the value in the interface{}. +func TypeOf(i interface{}) Type { + eface := *(*emptyInterface)(unsafe.Pointer(&i)) + return toType(eface.typ) } -// Typeof returns the reflection Type of the value in the interface{}. -func Typeof(i interface{}) Type { return canonicalize(toType(unsafe.Typeof(i))) } - // ptrMap is the cache for PtrTo. var ptrMap struct { sync.RWMutex - m map[Type]*PtrType + m map[*commonType]*ptrType } -// runtimePtrType is the runtime layout for a *PtrType. -// The memory immediately before the *PtrType is always -// the canonical runtime.Type to be used for a *runtime.Type -// describing this PtrType. -type runtimePtrType struct { - runtime.Type - runtime.PtrType +func (t *commonType) runtimeType() *runtime.Type { + return (*runtime.Type)(unsafe.Pointer(t)) } // PtrTo returns the pointer type with element t. // For example, if t represents type Foo, PtrTo(t) represents *Foo. -func PtrTo(t Type) *PtrType { +func PtrTo(t Type) Type { // If t records its pointer-to type, use it. - ct := t.common() + ct := t.(*commonType) if p := ct.ptrToThis; p != nil { - return runtimeToType(p).(*PtrType) + return toType(p) } // Otherwise, synthesize one. @@ -793,17 +869,17 @@ func PtrTo(t Type) *PtrType { // the type structures in read-only memory. ptrMap.RLock() if m := ptrMap.m; m != nil { - if p := m[t]; p != nil { + if p := m[ct]; p != nil { ptrMap.RUnlock() - return p + return p.commonType.toType() } } ptrMap.RUnlock() ptrMap.Lock() if ptrMap.m == nil { - ptrMap.m = make(map[Type]*PtrType) + ptrMap.m = make(map[*commonType]*ptrType) } - p := ptrMap.m[t] + p := ptrMap.m[ct] if p != nil { // some other goroutine won the race and created it ptrMap.Unlock() @@ -814,11 +890,11 @@ func PtrTo(t Type) *PtrType { rp := new(runtime.PtrType) - // initialize rp used *byte's PtrType as a prototype. + // initialize p using *byte's PtrType as a prototype. // have to do assignment as PtrType, not runtime.PtrType, // in order to write to unexported fields. - p = (*PtrType)(unsafe.Pointer(rp)) - bp := (*PtrType)(unsafe.Pointer(unsafe.Typeof((*byte)(nil)).(*runtime.PtrType))) + p = (*ptrType)(unsafe.Pointer(rp)) + bp := (*ptrType)(unsafe.Pointer(unsafe.Typeof((*byte)(nil)).(*runtime.PtrType))) *p = *bp s := "*" + *ct.string @@ -833,9 +909,170 @@ func PtrTo(t Type) *PtrType { p.uncommonType = nil p.ptrToThis = nil - p.elem = rt + p.elem = (*runtime.Type)(unsafe.Pointer(ct)) - ptrMap.m[t] = (*PtrType)(unsafe.Pointer(rp)) + ptrMap.m[ct] = p ptrMap.Unlock() - return p + return p.commonType.toType() +} + +func (t *commonType) Implements(u Type) bool { + if u == nil { + panic("reflect: nil type passed to Type.Implements") + } + if u.Kind() != Interface { + panic("reflect: non-interface type passed to Type.Implements") + } + return implements(u.(*commonType), t) +} + +func (t *commonType) AssignableTo(u Type) bool { + if u == nil { + panic("reflect: nil type passed to Type.AssignableTo") + } + uu := u.(*commonType) + return directlyAssignable(uu, t) || implements(uu, t) +} + +// implements returns true if the type V implements the interface type T. +func implements(T, V *commonType) bool { + if T.Kind() != Interface { + return false + } + t := (*interfaceType)(unsafe.Pointer(T)) + if len(t.methods) == 0 { + return true + } + + // The same algorithm applies in both cases, but the + // method tables for an interface type and a concrete type + // are different, so the code is duplicated. + // In both cases the algorithm is a linear scan over the two + // lists - T's methods and V's methods - simultaneously. + // Since method tables are stored in a unique sorted order + // (alphabetical, with no duplicate method names), the scan + // through V's methods must hit a match for each of T's + // methods along the way, or else V does not implement T. + // This lets us run the scan in overall linear time instead of + // the quadratic time a naive search would require. + // See also ../runtime/iface.c. + if V.Kind() == Interface { + v := (*interfaceType)(unsafe.Pointer(V)) + i := 0 + for j := 0; j < len(v.methods); j++ { + tm := &t.methods[i] + vm := &v.methods[j] + if *vm.name == *tm.name && (vm.pkgPath == tm.pkgPath || (vm.pkgPath != nil && tm.pkgPath != nil && *vm.pkgPath == *tm.pkgPath)) && toType(vm.typ).common() == toType(tm.typ).common() { + if i++; i >= len(t.methods) { + return true + } + } + } + return false + } + + v := V.uncommon() + if v == nil { + return false + } + i := 0 + for j := 0; j < len(v.methods); j++ { + tm := &t.methods[i] + vm := &v.methods[j] + if *vm.name == *tm.name && (vm.pkgPath == tm.pkgPath || (vm.pkgPath != nil && tm.pkgPath != nil && *vm.pkgPath == *tm.pkgPath)) && toType(vm.mtyp).common() == toType(tm.typ).common() { + if i++; i >= len(t.methods) { + return true + } + } + } + return false +} + +// directlyAssignable returns true if a value x of type V can be directly +// assigned (using memmove) to a value of type T. +// http://golang.org/doc/go_spec.html#Assignability +// Ignoring the interface rules (implemented elsewhere) +// and the ideal constant rules (no ideal constants at run time). +func directlyAssignable(T, V *commonType) bool { + // x's type V is identical to T? + if T == V { + return true + } + + // Otherwise at least one of T and V must be unnamed + // and they must have the same kind. + if T.Name() != "" && V.Name() != "" || T.Kind() != V.Kind() { + return false + } + + // x's type T and V have identical underlying types. + // Since at least one is unnamed, only the composite types + // need to be considered. + switch T.Kind() { + case Array: + return T.Elem() == V.Elem() && T.Len() == V.Len() + + case Chan: + // Special case: + // x is a bidirectional channel value, T is a channel type, + // and x's type V and T have identical element types. + if V.ChanDir() == BothDir && T.Elem() == V.Elem() { + return true + } + + // Otherwise continue test for identical underlying type. + return V.ChanDir() == T.ChanDir() && T.Elem() == V.Elem() + + case Func: + t := (*funcType)(unsafe.Pointer(T)) + v := (*funcType)(unsafe.Pointer(V)) + if t.dotdotdot != v.dotdotdot || len(t.in) != len(v.in) || len(t.out) != len(v.out) { + return false + } + for i, typ := range t.in { + if typ != v.in[i] { + return false + } + } + for i, typ := range t.out { + if typ != v.out[i] { + return false + } + } + return true + + case Interface: + t := (*interfaceType)(unsafe.Pointer(T)) + v := (*interfaceType)(unsafe.Pointer(V)) + if len(t.methods) == 0 && len(v.methods) == 0 { + return true + } + // Might have the same methods but still + // need a run time conversion. + return false + + case Map: + return T.Key() == V.Key() && T.Elem() == V.Elem() + + case Ptr, Slice: + return T.Elem() == V.Elem() + + case Struct: + t := (*structType)(unsafe.Pointer(T)) + v := (*structType)(unsafe.Pointer(V)) + if len(t.fields) != len(v.fields) { + return false + } + for i := range t.fields { + tf := &t.fields[i] + vf := &v.fields[i] + if tf.name != vf.name || tf.pkgPath != vf.pkgPath || + tf.typ != vf.typ || tf.tag != vf.tag || tf.offset != vf.offset { + return false + } + } + return true + } + + return false } diff --git a/libgo/go/reflect/value.go b/libgo/go/reflect/value.go index ebc87d45b92..ea48b02f14b 100644 --- a/libgo/go/reflect/value.go +++ b/libgo/go/reflect/value.go @@ -7,17 +7,16 @@ package reflect import ( "math" "runtime" + "strconv" "unsafe" ) const ptrSize = uintptr(unsafe.Sizeof((*byte)(nil))) const cannotSet = "cannot set value obtained from unexported struct field" -type addr unsafe.Pointer - // TODO: This will have to go away when // the new gc goes in. -func memmove(adst, asrc addr, n uintptr) { +func memmove(adst, asrc unsafe.Pointer, n uintptr) { dst := uintptr(adst) src := uintptr(asrc) switch { @@ -26,1291 +25,1667 @@ func memmove(adst, asrc addr, n uintptr) { // careful: i is unsigned for i := n; i > 0; { i-- - *(*byte)(addr(dst + i)) = *(*byte)(addr(src + i)) + *(*byte)(unsafe.Pointer(dst + i)) = *(*byte)(unsafe.Pointer(src + i)) } case (n|src|dst)&(ptrSize-1) != 0: // byte copy forward for i := uintptr(0); i < n; i++ { - *(*byte)(addr(dst + i)) = *(*byte)(addr(src + i)) + *(*byte)(unsafe.Pointer(dst + i)) = *(*byte)(unsafe.Pointer(src + i)) } default: // word copy forward for i := uintptr(0); i < n; i += ptrSize { - *(*uintptr)(addr(dst + i)) = *(*uintptr)(addr(src + i)) + *(*uintptr)(unsafe.Pointer(dst + i)) = *(*uintptr)(unsafe.Pointer(src + i)) } } } -// Value is the common interface to reflection values. -// The implementations of Value (e.g., ArrayValue, StructValue) -// have additional type-specific methods. -type Value interface { - // Type returns the value's type. - Type() Type - - // Interface returns the value as an interface{}. - Interface() interface{} +// Value is the reflection interface to a Go value. +// +// Not all methods apply to all kinds of values. Restrictions, +// if any, are noted in the documentation for each method. +// Use the Kind method to find out the kind of value before +// calling kind-specific methods. Calling a method +// inappropriate to the kind of type causes a run time panic. +// +// The zero Value represents no value. +// Its IsValid method returns false, its Kind method returns Invalid, +// its String method returns "", and all other methods panic. +// Most functions and methods never return an invalid value. +// If one does, its documentation states the conditions explicitly. +// +// The fields of Value are exported so that clients can copy and +// pass Values around, but they should not be edited or inspected +// directly. A future language change may make it possible not to +// export these fields while still keeping Values usable as values. +type Value struct { + Internal interface{} + InternalMethod int +} + +// A ValueError occurs when a Value method is invoked on +// a Value that does not support it. Such cases are documented +// in the description of each method. +type ValueError struct { + Method string + Kind Kind +} + +func (e *ValueError) String() string { + if e.Kind == 0 { + return "reflect: call of " + e.Method + " on zero Value" + } + return "reflect: call of " + e.Method + " on " + e.Kind.String() + " Value" +} + +// methodName returns the name of the calling method, +// assumed to be two stack frames above. +func methodName() string { + pc, _, _, _ := runtime.Caller(2) + f := runtime.FuncForPC(pc) + if f == nil { + return "unknown method" + } + return f.Name() +} + +// An iword is the word that would be stored in an +// interface to represent a given value v. Specifically, if v is +// bigger than a pointer, its word is a pointer to v's data. +// Otherwise, its word is a zero uintptr with the data stored +// in the leading bytes. +type iword uintptr + +func loadIword(p unsafe.Pointer, size uintptr) iword { + // Run the copy ourselves instead of calling memmove + // to avoid moving v to the heap. + w := iword(0) + switch size { + default: + panic("reflect: internal error: loadIword of " + strconv.Itoa(int(size)) + "-byte value") + case 0: + case 1: + *(*uint8)(unsafe.Pointer(&w)) = *(*uint8)(p) + case 2: + *(*uint16)(unsafe.Pointer(&w)) = *(*uint16)(p) + case 3: + *(*[3]byte)(unsafe.Pointer(&w)) = *(*[3]byte)(p) + case 4: + *(*uint32)(unsafe.Pointer(&w)) = *(*uint32)(p) + case 5: + *(*[5]byte)(unsafe.Pointer(&w)) = *(*[5]byte)(p) + case 6: + *(*[6]byte)(unsafe.Pointer(&w)) = *(*[6]byte)(p) + case 7: + *(*[7]byte)(unsafe.Pointer(&w)) = *(*[7]byte)(p) + case 8: + *(*uint64)(unsafe.Pointer(&w)) = *(*uint64)(p) + } + return w +} + +func storeIword(p unsafe.Pointer, w iword, size uintptr) { + // Run the copy ourselves instead of calling memmove + // to avoid moving v to the heap. + switch size { + default: + panic("reflect: internal error: storeIword of " + strconv.Itoa(int(size)) + "-byte value") + case 0: + case 1: + *(*uint8)(p) = *(*uint8)(unsafe.Pointer(&w)) + case 2: + *(*uint16)(p) = *(*uint16)(unsafe.Pointer(&w)) + case 3: + *(*[3]byte)(p) = *(*[3]byte)(unsafe.Pointer(&w)) + case 4: + *(*uint32)(p) = *(*uint32)(unsafe.Pointer(&w)) + case 5: + *(*[5]byte)(p) = *(*[5]byte)(unsafe.Pointer(&w)) + case 6: + *(*[6]byte)(p) = *(*[6]byte)(unsafe.Pointer(&w)) + case 7: + *(*[7]byte)(p) = *(*[7]byte)(unsafe.Pointer(&w)) + case 8: + *(*uint64)(p) = *(*uint64)(unsafe.Pointer(&w)) + } +} + +// emptyInterface is the header for an interface{} value. +type emptyInterface struct { + typ *runtime.Type + word iword +} + +// nonEmptyInterface is the header for a interface value with methods. +type nonEmptyInterface struct { + // see ../runtime/iface.c:/Itab + itab *struct { + typ *runtime.Type // dynamic concrete type + fun [100000]unsafe.Pointer // method table + } + word iword +} + +// Regarding the implementation of Value: +// +// The Internal interface is a true interface value in the Go sense, +// but it also serves as a (type, address) pair in whcih one cannot +// be changed separately from the other. That is, it serves as a way +// to prevent unsafe mutations of the Internal state even though +// we cannot (yet?) hide the field while preserving the ability for +// clients to make copies of Values. +// +// The internal method converts a Value into the expanded internalValue struct. +// If we could avoid exporting fields we'd probably make internalValue the +// definition of Value. +// +// If a Value is addressable (CanAddr returns true), then the Internal +// interface value holds a pointer to the actual field data, and Set stores +// through that pointer. If a Value is not addressable (CanAddr returns false), +// then the Internal interface value holds the actual value. +// +// In addition to whether a value is addressable, we track whether it was +// obtained by using an unexported struct field. Such values are allowed +// to be read, mainly to make fmt.Print more useful, but they are not +// allowed to be written. We call such values read-only. +// +// A Value can be set (via the Set, SetUint, etc. methods) only if it is both +// addressable and not read-only. +// +// The two permission bits - addressable and read-only - are stored in +// the bottom two bits of the type pointer in the interface value. +// +// ordinary value: Internal = value +// addressable value: Internal = value, Internal.typ |= flagAddr +// read-only value: Internal = value, Internal.typ |= flagRO +// addressable, read-only value: Internal = value, Internal.typ |= flagAddr | flagRO +// +// It is important that the read-only values have the extra bit set +// (as opposed to using the bit to mean writable), because client code +// can grab the interface field and try to use it. Having the extra bit +// set makes the type pointer compare not equal to any real type, +// so that a client cannot, say, write through v.Internal.(*int). +// The runtime routines that access interface types reject types with +// low bits set. +// +// If a Value fv = v.Method(i), then fv = v with the InternalMethod +// field set to i+1. Methods are never addressable. +// +// All in all, this is a lot of effort just to avoid making this new API +// depend on a language change we'll probably do anyway, but +// it's helpful to keep the two separate, and much of the logic is +// necessary to implement the Interface method anyway. - // CanSet returns true if the value can be changed. - // Values obtained by the use of non-exported struct fields - // can be used in Get but not Set. - // If CanSet returns false, calling the type-specific Set will panic. - CanSet() bool +const ( + flagAddr uint32 = 1 << iota // holds address of value + flagRO // read-only - // SetValue assigns v to the value; v must have the same type as the value. - SetValue(v Value) + reflectFlags = 3 +) - // CanAddr returns true if the value's address can be obtained with Addr. - // Such values are called addressable. A value is addressable if it is - // an element of a slice, an element of an addressable array, - // a field of an addressable struct, the result of dereferencing a pointer, - // or the result of a call to NewValue, MakeChan, MakeMap, or MakeZero. - // If CanAddr returns false, calling Addr will panic. - CanAddr() bool +// An internalValue is the unpacked form of a Value. +// The zero Value unpacks to a zero internalValue +type internalValue struct { + typ *commonType // type of value + kind Kind // kind of value + flag uint32 + word iword + addr unsafe.Pointer + rcvr iword + method bool + nilmethod bool +} + +func (v Value) internal() internalValue { + var iv internalValue + eface := *(*emptyInterface)(unsafe.Pointer(&v.Internal)) + p := uintptr(unsafe.Pointer(eface.typ)) + iv.typ = toCommonType((*runtime.Type)(unsafe.Pointer(p &^ reflectFlags))) + if iv.typ == nil { + return iv + } + iv.flag = uint32(p & reflectFlags) + iv.word = eface.word + if iv.flag&flagAddr != 0 { + iv.addr = unsafe.Pointer(uintptr(iv.word)) + iv.typ = iv.typ.Elem().common() + if Kind(iv.typ.kind) == Ptr || Kind(iv.typ.kind) == UnsafePointer { + iv.word = loadIword(iv.addr, iv.typ.size) + } + } else { + if Kind(iv.typ.kind) != Ptr && Kind(iv.typ.kind) != UnsafePointer { + iv.addr = unsafe.Pointer(uintptr(iv.word)) + } + } + iv.kind = iv.typ.Kind() + + // Is this a method? If so, iv describes the receiver. + // Rewrite to describe the method function. + if v.InternalMethod != 0 { + // If this Value is a method value (x.Method(i) for some Value x) + // then we will invoke it using the interface form of the method, + // which always passes the receiver as a single word. + // Record that information. + i := v.InternalMethod - 1 + if iv.kind == Interface { + it := (*interfaceType)(unsafe.Pointer(iv.typ)) + if i < 0 || i >= len(it.methods) { + panic("reflect: broken Value") + } + m := &it.methods[i] + if m.pkgPath != nil { + iv.flag |= flagRO + } + iv.typ = toCommonType(m.typ) + iface := (*nonEmptyInterface)(iv.addr) + if iface.itab == nil { + iv.word = 0 + iv.nilmethod = true + } else { + iv.word = iword(uintptr(iface.itab.fun[i])) + } + iv.rcvr = iface.word + } else { + ut := iv.typ.uncommon() + if ut == nil || i < 0 || i >= len(ut.methods) { + panic("reflect: broken Value") + } + m := &ut.methods[i] + if m.pkgPath != nil { + iv.flag |= flagRO + } + iv.typ = toCommonType(m.mtyp) + iv.rcvr = iv.word + iv.word = iword(uintptr(m.tfn)) + } + if iv.word != 0 { + p := new(iword) + *p = iv.word + iv.word = iword(uintptr(unsafe.Pointer(p))) + } + iv.kind = Func + iv.method = true + iv.flag &^= flagAddr + iv.addr = unsafe.Pointer(uintptr(iv.word)) + } - // Addr returns the address of the value. - // If the value is not addressable, Addr panics. - // Addr is typically used to obtain a pointer to a struct field or slice element - // in order to call a method that requires a pointer receiver. - Addr() *PtrValue + return iv +} - // UnsafeAddr returns a pointer to the underlying data. - // It is for advanced clients that also import the "unsafe" package. - UnsafeAddr() uintptr +// packValue returns a Value with the given flag bits, type, and interface word. +func packValue(flag uint32, typ *runtime.Type, word iword) Value { + if typ == nil { + panic("packValue") + } + t := uintptr(unsafe.Pointer(typ)) + t |= uintptr(flag) + eface := emptyInterface{(*runtime.Type)(unsafe.Pointer(t)), word} + return Value{Internal: *(*interface{})(unsafe.Pointer(&eface))} +} - // Method returns a FuncValue corresponding to the value's i'th method. - // The arguments to a Call on the returned FuncValue - // should not include a receiver; the FuncValue will use - // the value as the receiver. - Method(i int) *FuncValue +// valueFromAddr returns a Value using the given type and address. +func valueFromAddr(flag uint32, typ Type, addr unsafe.Pointer) Value { + if flag&flagAddr != 0 { + // Addressable, so the internal value is + // an interface containing a pointer to the real value. + return packValue(flag, PtrTo(typ).runtimeType(), iword(uintptr(addr))) + } - getAddr() addr + var w iword + if k := typ.Kind(); k == Ptr || k == UnsafePointer { + // In line, so the interface word is the actual value. + w = loadIword(addr, typ.Size()) + } else { + // Not in line: the interface word is the address. + w = iword(uintptr(addr)) + } + return packValue(flag, typ.runtimeType(), w) } -// flags for value -const ( - canSet uint32 = 1 << iota // can set value (write to *v.addr) - canAddr // can take address of value - canStore // can store through value (write to **v.addr) -) - -// value is the common implementation of most values. -// It is embedded in other, public struct types, but always -// with a unique tag like "uint" or "float" so that the client cannot -// convert from, say, *UintValue to *FloatValue. -type value struct { - typ Type - addr addr - flag uint32 +// valueFromIword returns a Value using the given type and interface word. +func valueFromIword(flag uint32, typ Type, w iword) Value { + if flag&flagAddr != 0 { + panic("reflect: internal error: valueFromIword addressable") + } + return packValue(flag, typ.runtimeType(), w) } -func (v *value) Type() Type { return v.typ } +func (iv internalValue) mustBe(want Kind) { + if iv.kind != want { + panic(&ValueError{methodName(), iv.kind}) + } +} -func (v *value) Addr() *PtrValue { - if !v.CanAddr() { - panic("reflect: cannot take address of value") +func (iv internalValue) mustBeExported() { + if iv.kind == 0 { + panic(&ValueError{methodName(), iv.kind}) } - a := v.addr - flag := canSet - if v.CanSet() { - flag |= canStore + if iv.flag&flagRO != 0 { + panic(methodName() + " using value obtained using unexported field") } - // We could safely set canAddr here too - - // the caller would get the address of a - - // but it doesn't match the Go model. - // The language doesn't let you say &&v. - return newValue(PtrTo(v.typ), addr(&a), flag).(*PtrValue) } -func (v *value) UnsafeAddr() uintptr { return uintptr(v.addr) } - -func (v *value) getAddr() addr { return v.addr } - -func (v *value) Interface() interface{} { - if typ, ok := v.typ.(*InterfaceType); ok { - // There are two different representations of interface values, - // one if the interface type has methods and one if it doesn't. - // These two representations require different expressions - // to extract correctly. - if typ.NumMethod() == 0 { - // Extract as interface value without methods. - return *(*interface{})(v.addr) - } - // Extract from v.addr as interface value with methods. - return *(*interface { - m() - })(v.addr) +func (iv internalValue) mustBeAssignable() { + if iv.kind == 0 { + panic(&ValueError{methodName(), iv.kind}) + } + // Assignable if addressable and not read-only. + if iv.flag&flagRO != 0 { + panic(methodName() + " using value obtained using unexported field") + } + if iv.flag&flagAddr == 0 { + panic(methodName() + " using unaddressable value") } - return unsafe.Unreflect(v.typ, unsafe.Pointer(v.addr)) } -func (v *value) CanSet() bool { return v.flag&canSet != 0 } +// Addr returns a pointer value representing the address of v. +// It panics if CanAddr() returns false. +// Addr is typically used to obtain a pointer to a struct field +// or slice element in order to call a method that requires a +// pointer receiver. +func (v Value) Addr() Value { + iv := v.internal() + if iv.flag&flagAddr == 0 { + panic("reflect.Value.Addr of unaddressable value") + } + return valueFromIword(iv.flag&flagRO, PtrTo(iv.typ.toType()), iword(uintptr(iv.addr))) +} -func (v *value) CanAddr() bool { return v.flag&canAddr != 0 } +// Bool returns v's underlying value. +// It panics if v's kind is not Bool. +func (v Value) Bool() bool { + iv := v.internal() + iv.mustBe(Bool) + return *(*bool)(unsafe.Pointer(iv.addr)) +} +// CanAddr returns true if the value's address can be obtained with Addr. +// Such values are called addressable. A value is addressable if it is +// an element of a slice, an element of an addressable array, +// a field of an addressable struct, or the result of dereferencing a pointer. +// If CanAddr returns false, calling Addr will panic. +func (v Value) CanAddr() bool { + iv := v.internal() + return iv.flag&flagAddr != 0 +} -/* - * basic types - */ +// CanSet returns true if the value of v can be changed. +// A Value can be changed only if it is addressable and was not +// obtained by the use of unexported struct fields. +// If CanSet returns false, calling Set or any type-specific +// setter (e.g., SetBool, SetInt64) will panic. +func (v Value) CanSet() bool { + iv := v.internal() + return iv.flag&(flagAddr|flagRO) == flagAddr +} -// BoolValue represents a bool value. -type BoolValue struct { - value "bool" +// Call calls the function v with the input arguments in. +// For example, if len(in) == 3, v.Call(in) represents the Go call v(in[0], in[1], in[2]). +// Call panics if v's Kind is not Func. +// It returns the output results as Values. +// As in Go, each input argument must be assignable to the +// type of the function's corresponding input parameter. +// If v is a variadic function, Call creates the variadic slice parameter +// itself, copying in the corresponding values. +func (v Value) Call(in []Value) []Value { + iv := v.internal() + iv.mustBe(Func) + iv.mustBeExported() + return iv.call("Call", in) } -// Get returns the underlying bool value. -func (v *BoolValue) Get() bool { return *(*bool)(v.addr) } +// CallSlice calls the variadic function v with the input arguments in, +// assigning the slice in[len(in)-1] to v's final variadic argument. +// For example, if len(in) == 3, v.Call(in) represents the Go call v(in[0], in[1], in[2]...). +// Call panics if v's Kind is not Func or if v is not variadic. +// It returns the output results as Values. +// As in Go, each input argument must be assignable to the +// type of the function's corresponding input parameter. +func (v Value) CallSlice(in []Value) []Value { + iv := v.internal() + iv.mustBe(Func) + iv.mustBeExported() + return iv.call("CallSlice", in) +} -// Set sets v to the value x. -func (v *BoolValue) Set(x bool) { - if !v.CanSet() { - panic(cannotSet) +func (iv internalValue) call(method string, in []Value) []Value { + if iv.word == 0 { + if iv.nilmethod { + panic("reflect.Value.Call: call of method on nil interface value") + } + panic("reflect.Value.Call: call of nil function") } - *(*bool)(v.addr) = x -} -// Set sets v to the value x. -func (v *BoolValue) SetValue(x Value) { v.Set(x.(*BoolValue).Get()) } + isSlice := method == "CallSlice" + t := iv.typ + n := t.NumIn() + if isSlice { + if !t.IsVariadic() { + panic("reflect: CallSlice of non-variadic function") + } + if len(in) < n { + panic("reflect: CallSlice with too few input arguments") + } + if len(in) > n { + panic("reflect: CallSlice with too many input arguments") + } + } else { + if t.IsVariadic() { + n-- + } + if len(in) < n { + panic("reflect: Call with too few input arguments") + } + if !t.IsVariadic() && len(in) > n { + panic("reflect: Call with too many input arguments") + } + } + for _, x := range in { + if x.Kind() == Invalid { + panic("reflect: " + method + " using zero Value argument") + } + } + for i := 0; i < n; i++ { + if xt, targ := in[i].Type(), t.In(i); !xt.AssignableTo(targ) { + panic("reflect: " + method + " using " + xt.String() + " as type " + targ.String()) + } + } + if !isSlice && t.IsVariadic() { + // prepare slice for remaining values + m := len(in) - n + slice := MakeSlice(t.In(n), m, m) + elem := t.In(n).Elem() + for i := 0; i < m; i++ { + x := in[n+i] + if xt := x.Type(); !xt.AssignableTo(elem) { + panic("reflect: cannot use " + xt.String() + " as type " + elem.String() + " in " + method) + } + slice.Index(i).Set(x) + } + origIn := in + in = make([]Value, n+1) + copy(in[:n], origIn) + in[n] = slice + } -// FloatValue represents a float value. -type FloatValue struct { - value "float" -} + nin := len(in) + if nin != t.NumIn() { + panic("reflect.Value.Call: wrong argument count") + } + nout := t.NumOut() -// Get returns the underlying int value. -func (v *FloatValue) Get() float64 { - switch v.typ.Kind() { - case Float32: - return float64(*(*float32)(v.addr)) - case Float64: - return *(*float64)(v.addr) + if iv.method { + nin++ + } + params := make([]unsafe.Pointer, nin) + delta := 0 + off := 0 + if iv.method { + // Hard-wired first argument. + p := new(iword) + *p = iv.rcvr + params[0] = unsafe.Pointer(p) + off = 1 } - panic("reflect: invalid float kind") -} -// Set sets v to the value x. -func (v *FloatValue) Set(x float64) { - if !v.CanSet() { - panic(cannotSet) + first_pointer := false + for i, v := range in { + siv := v.internal() + siv.mustBeExported() + targ := t.In(i).(*commonType) + siv = convertForAssignment("reflect.Value.Call", nil, targ, siv) + if siv.addr == nil { + p := new(unsafe.Pointer) + *p = unsafe.Pointer(uintptr(siv.word)) + params[off] = unsafe.Pointer(p) + } else { + params[off] = siv.addr + } + if i == 0 && Kind(targ.kind) != Ptr && !iv.method && isMethod(iv.typ) { + p := new(unsafe.Pointer) + *p = params[off] + params[off] = unsafe.Pointer(p) + first_pointer = true + } + off++ } - switch v.typ.Kind() { - default: - panic("reflect: invalid float kind") - case Float32: - *(*float32)(v.addr) = float32(x) - case Float64: - *(*float64)(v.addr) = x + + ret := make([]Value, nout) + results := make([]unsafe.Pointer, nout) + for i := 0; i < nout; i++ { + v := New(t.Out(i)) + results[i] = unsafe.Pointer(v.Pointer()) + ret[i] = Indirect(v) } + + call(t, *(*unsafe.Pointer)(iv.addr), iv.method, first_pointer, ¶ms[0], &results[0]) + + return ret } -// Overflow returns true if x cannot be represented by the type of v. -func (v *FloatValue) Overflow(x float64) bool { - if v.typ.Size() == 8 { +// gccgo specific test to see if typ is a method. We can tell by +// looking at the string to see if there is a receiver. We need this +// because for gccgo all methods take pointer receivers. +func isMethod(t *commonType) bool { + if Kind(t.kind) != Func { return false } - if x < 0 { - x = -x + s := *t.string + parens := 0 + params := 0 + sawRet := false + for i, c := range s { + if c == '(' { + parens++ + params++ + } else if c == ')' { + parens-- + } else if parens == 0 && c == ' ' && s[i + 1] != '(' && !sawRet { + params++ + sawRet = true + } } - return math.MaxFloat32 < x && x <= math.MaxFloat64 + return params > 2 } -// Set sets v to the value x. -func (v *FloatValue) SetValue(x Value) { v.Set(x.(*FloatValue).Get()) } +// Cap returns v's capacity. +// It panics if v's Kind is not Array, Chan, or Slice. +func (v Value) Cap() int { + iv := v.internal() + switch iv.kind { + case Array: + return iv.typ.Len() + case Chan: + return int(chancap(*(*iword)(iv.addr))) + case Slice: + return (*SliceHeader)(iv.addr).Cap + } + panic(&ValueError{"reflect.Value.Cap", iv.kind}) +} -// ComplexValue represents a complex value. -type ComplexValue struct { - value "complex" +// Close closes the channel v. +// It panics if v's Kind is not Chan. +func (v Value) Close() { + iv := v.internal() + iv.mustBe(Chan) + iv.mustBeExported() + ch := *(*iword)(iv.addr) + chanclose(ch) } -// Get returns the underlying complex value. -func (v *ComplexValue) Get() complex128 { - switch v.typ.Kind() { +// Complex returns v's underlying value, as a complex128. +// It panics if v's Kind is not Complex64 or Complex128 +func (v Value) Complex() complex128 { + iv := v.internal() + switch iv.kind { case Complex64: - return complex128(*(*complex64)(v.addr)) + return complex128(*(*complex64)(iv.addr)) case Complex128: - return *(*complex128)(v.addr) + return *(*complex128)(iv.addr) } - panic("reflect: invalid complex kind") + panic(&ValueError{"reflect.Value.Complex", iv.kind}) } -// Set sets v to the value x. -func (v *ComplexValue) Set(x complex128) { - if !v.CanSet() { - panic(cannotSet) +// Elem returns the value that the interface v contains +// or that the pointer v points to. +// It panics if v's Kind is not Interface or Ptr. +// It returns the zero Value if v is nil. +func (v Value) Elem() Value { + iv := v.internal() + return iv.Elem() +} + +func (iv internalValue) Elem() Value { + switch iv.kind { + case Interface: + // Empty interface and non-empty interface have different layouts. + // Convert to empty interface. + var eface emptyInterface + if iv.typ.NumMethod() == 0 { + eface = *(*emptyInterface)(iv.addr) + } else { + iface := (*nonEmptyInterface)(iv.addr) + if iface.itab != nil { + eface.typ = iface.itab.typ + } + eface.word = iface.word + } + if eface.typ == nil { + return Value{} + } + return valueFromIword(iv.flag&flagRO, toType(eface.typ), eface.word) + + case Ptr: + // The returned value's address is v's value. + if iv.word == 0 { + return Value{} + } + return valueFromAddr(iv.flag&flagRO|flagAddr, iv.typ.Elem(), unsafe.Pointer(uintptr(iv.word))) } - switch v.typ.Kind() { - default: - panic("reflect: invalid complex kind") - case Complex64: - *(*complex64)(v.addr) = complex64(x) - case Complex128: - *(*complex128)(v.addr) = x + panic(&ValueError{"reflect.Value.Elem", iv.kind}) +} + +// Field returns the i'th field of the struct v. +// It panics if v's Kind is not Struct or i is out of range. +func (v Value) Field(i int) Value { + iv := v.internal() + iv.mustBe(Struct) + t := iv.typ.toType() + if i < 0 || i >= t.NumField() { + panic("reflect: Field index out of range") + } + f := t.Field(i) + + // Inherit permission bits from v. + flag := iv.flag + // Using an unexported field forces flagRO. + if f.PkgPath != "" { + flag |= flagRO } + return valueFromValueOffset(flag, f.Type, iv, f.Offset) } -// Set sets v to the value x. -func (v *ComplexValue) SetValue(x Value) { v.Set(x.(*ComplexValue).Get()) } +// valueFromValueOffset returns a sub-value of outer +// (outer is an array or a struct) with the given flag and type +// starting at the given byte offset into outer. +func valueFromValueOffset(flag uint32, typ Type, outer internalValue, offset uintptr) Value { + if outer.addr != nil { + return valueFromAddr(flag, typ, unsafe.Pointer(uintptr(outer.addr)+offset)) + } -// IntValue represents an int value. -type IntValue struct { - value "int" + // outer is so tiny it is in line. + // We have to use outer.word and derive + // the new word (it cannot possibly be bigger). + // In line, so not addressable. + if flag&flagAddr != 0 { + panic("reflect: internal error: misuse of valueFromValueOffset") + } + b := *(*[ptrSize]byte)(unsafe.Pointer(&outer.word)) + for i := uintptr(0); i < typ.Size(); i++ { + b[i] = b[offset+i] + } + for i := typ.Size(); i < ptrSize; i++ { + b[i] = 0 + } + w := *(*iword)(unsafe.Pointer(&b)) + return valueFromIword(flag, typ, w) } -// Get returns the underlying int value. -func (v *IntValue) Get() int64 { - switch v.typ.Kind() { - case Int: - return int64(*(*int)(v.addr)) - case Int8: - return int64(*(*int8)(v.addr)) - case Int16: - return int64(*(*int16)(v.addr)) - case Int32: - return int64(*(*int32)(v.addr)) - case Int64: - return *(*int64)(v.addr) +// FieldByIndex returns the nested field corresponding to index. +// It panics if v's Kind is not struct. +func (v Value) FieldByIndex(index []int) Value { + v.internal().mustBe(Struct) + for i, x := range index { + if i > 0 { + if v.Kind() == Ptr && v.Elem().Kind() == Struct { + v = v.Elem() + } + } + v = v.Field(x) } - panic("reflect: invalid int kind") + return v } -// Set sets v to the value x. -func (v *IntValue) Set(x int64) { - if !v.CanSet() { - panic(cannotSet) +// FieldByName returns the struct field with the given name. +// It returns the zero Value if no field was found. +// It panics if v's Kind is not struct. +func (v Value) FieldByName(name string) Value { + iv := v.internal() + iv.mustBe(Struct) + if f, ok := iv.typ.FieldByName(name); ok { + return v.FieldByIndex(f.Index) + } + return Value{} +} + +// FieldByNameFunc returns the struct field with a name +// that satisfies the match function. +// It panics if v's Kind is not struct. +// It returns the zero Value if no field was found. +func (v Value) FieldByNameFunc(match func(string) bool) Value { + v.internal().mustBe(Struct) + if f, ok := v.Type().FieldByNameFunc(match); ok { + return v.FieldByIndex(f.Index) + } + return Value{} +} + +// Float returns v's underlying value, as an float64. +// It panics if v's Kind is not Float32 or Float64 +func (v Value) Float() float64 { + iv := v.internal() + switch iv.kind { + case Float32: + return float64(*(*float32)(iv.addr)) + case Float64: + return *(*float64)(iv.addr) } - switch v.typ.Kind() { + panic(&ValueError{"reflect.Value.Float", iv.kind}) +} + +// Index returns v's i'th element. +// It panics if v's Kind is not Array or Slice or i is out of range. +func (v Value) Index(i int) Value { + iv := v.internal() + switch iv.kind { default: - panic("reflect: invalid int kind") + panic(&ValueError{"reflect.Value.Index", iv.kind}) + case Array: + flag := iv.flag // element flag same as overall array + t := iv.typ.toType() + if i < 0 || i > t.Len() { + panic("reflect: array index out of range") + } + typ := t.Elem() + return valueFromValueOffset(flag, typ, iv, uintptr(i)*typ.Size()) + + case Slice: + // Element flag same as Elem of Ptr. + // Addressable, possibly read-only. + flag := iv.flag&flagRO | flagAddr + s := (*SliceHeader)(iv.addr) + if i < 0 || i >= s.Len { + panic("reflect: slice index out of range") + } + typ := iv.typ.Elem() + addr := unsafe.Pointer(s.Data + uintptr(i)*typ.Size()) + return valueFromAddr(flag, typ, addr) + } + + panic("not reached") +} + +// Int returns v's underlying value, as an int64. +// It panics if v's Kind is not Int, Int8, Int16, Int32, or Int64. +func (v Value) Int() int64 { + iv := v.internal() + switch iv.kind { case Int: - *(*int)(v.addr) = int(x) + return int64(*(*int)(iv.addr)) case Int8: - *(*int8)(v.addr) = int8(x) + return int64(*(*int8)(iv.addr)) case Int16: - *(*int16)(v.addr) = int16(x) + return int64(*(*int16)(iv.addr)) case Int32: - *(*int32)(v.addr) = int32(x) + return int64(*(*int32)(iv.addr)) case Int64: - *(*int64)(v.addr) = x + return *(*int64)(iv.addr) } + panic(&ValueError{"reflect.Value.Int", iv.kind}) } -// Set sets v to the value x. -func (v *IntValue) SetValue(x Value) { v.Set(x.(*IntValue).Get()) } - -// Overflow returns true if x cannot be represented by the type of v. -func (v *IntValue) Overflow(x int64) bool { - bitSize := uint(v.typ.Bits()) - trunc := (x << (64 - bitSize)) >> (64 - bitSize) - return x != trunc -} - -// StringHeader is the runtime representation of a string. -type StringHeader struct { - Data uintptr - Len int +// CanInterface returns true if Interface can be used without panicking. +func (v Value) CanInterface() bool { + iv := v.internal() + if iv.kind == Invalid { + panic(&ValueError{"reflect.Value.CanInterface", iv.kind}) + } + // TODO(rsc): Check flagRO too. Decide what to do about asking for + // interface for a value obtained via an unexported field. + // If the field were of a known type, say chan int or *sync.Mutex, + // the caller could interfere with the data after getting the + // interface. But fmt.Print depends on being able to look. + // Now that reflect is more efficient the special cases in fmt + // might be less important. + return v.InternalMethod == 0 } -// StringValue represents a string value. -type StringValue struct { - value "string" +// Interface returns v's value as an interface{}. +// If v is a method obtained by invoking Value.Method +// (as opposed to Type.Method), Interface cannot return an +// interface value, so it panics. +func (v Value) Interface() interface{} { + return v.internal().Interface() } -// Get returns the underlying string value. -func (v *StringValue) Get() string { return *(*string)(v.addr) } +func (iv internalValue) Interface() interface{} { + if iv.method { + panic("reflect.Value.Interface: cannot create interface value for method with bound receiver") + } + /* + if v.flag()&noExport != 0 { + panic("reflect.Value.Interface: cannot return value obtained from unexported struct field") + } + */ -// Set sets v to the value x. -func (v *StringValue) Set(x string) { - if !v.CanSet() { - panic(cannotSet) + if iv.kind == Interface { + // Special case: return the element inside the interface. + // Won't recurse further because an interface cannot contain an interface. + if iv.IsNil() { + return nil + } + return iv.Elem().Interface() } - *(*string)(v.addr) = x -} -// Set sets v to the value x. -func (v *StringValue) SetValue(x Value) { v.Set(x.(*StringValue).Get()) } + // Non-interface value. + var eface emptyInterface + eface.typ = iv.typ.runtimeType() + eface.word = iv.word + return *(*interface{})(unsafe.Pointer(&eface)) +} -// UintValue represents a uint value. -type UintValue struct { - value "uint" +// InterfaceData returns the interface v's value as a uintptr pair. +// It panics if v's Kind is not Interface. +func (v Value) InterfaceData() [2]uintptr { + iv := v.internal() + iv.mustBe(Interface) + // We treat this as a read operation, so we allow + // it even for unexported data, because the caller + // has to import "unsafe" to turn it into something + // that can be abused. + return *(*[2]uintptr)(iv.addr) } -// Get returns the underlying uuint value. -func (v *UintValue) Get() uint64 { - switch v.typ.Kind() { - case Uint: - return uint64(*(*uint)(v.addr)) - case Uint8: - return uint64(*(*uint8)(v.addr)) - case Uint16: - return uint64(*(*uint16)(v.addr)) - case Uint32: - return uint64(*(*uint32)(v.addr)) - case Uint64: - return *(*uint64)(v.addr) - case Uintptr: - return uint64(*(*uintptr)(v.addr)) - } - panic("reflect: invalid uint kind") +// IsNil returns true if v is a nil value. +// It panics if v's Kind is not Chan, Func, Interface, Map, Ptr, or Slice. +func (v Value) IsNil() bool { + return v.internal().IsNil() } -// Set sets v to the value x. -func (v *UintValue) Set(x uint64) { - if !v.CanSet() { - panic(cannotSet) - } - switch v.typ.Kind() { - default: - panic("reflect: invalid uint kind") - case Uint: - *(*uint)(v.addr) = uint(x) - case Uint8: - *(*uint8)(v.addr) = uint8(x) - case Uint16: - *(*uint16)(v.addr) = uint16(x) - case Uint32: - *(*uint32)(v.addr) = uint32(x) - case Uint64: - *(*uint64)(v.addr) = x - case Uintptr: - *(*uintptr)(v.addr) = uintptr(x) +func (iv internalValue) IsNil() bool { + switch iv.kind { + case Ptr: + if iv.method { + panic("reflect: IsNil of method Value") + } + return iv.word == 0 + case Chan, Func, Map: + if iv.method { + panic("reflect: IsNil of method Value") + } + return *(*uintptr)(iv.addr) == 0 + case Interface, Slice: + // Both interface and slice are nil if first word is 0. + return *(*uintptr)(iv.addr) == 0 } + panic(&ValueError{"reflect.Value.IsNil", iv.kind}) } -// Overflow returns true if x cannot be represented by the type of v. -func (v *UintValue) Overflow(x uint64) bool { - bitSize := uint(v.typ.Bits()) - trunc := (x << (64 - bitSize)) >> (64 - bitSize) - return x != trunc +// IsValid returns true if v represents a value. +// It returns false if v is the zero Value. +// If IsValid returns false, all other methods except String panic. +// Most functions and methods never return an invalid value. +// If one does, its documentation states the conditions explicitly. +func (v Value) IsValid() bool { + return v.Internal != nil } -// Set sets v to the value x. -func (v *UintValue) SetValue(x Value) { v.Set(x.(*UintValue).Get()) } +// Kind returns v's Kind. +// If v is the zero Value (IsValid returns false), Kind returns Invalid. +func (v Value) Kind() Kind { + return v.internal().kind +} -// UnsafePointerValue represents an unsafe.Pointer value. -type UnsafePointerValue struct { - value "unsafe.Pointer" +// Len returns v's length. +// It panics if v's Kind is not Array, Chan, Map, or Slice. +func (v Value) Len() int { + iv := v.internal() + switch iv.kind { + case Array: + return iv.typ.Len() + case Chan: + return int(chanlen(*(*iword)(iv.addr))) + case Map: + return int(maplen(*(*iword)(iv.addr))) + case Slice: + return (*SliceHeader)(iv.addr).Len + } + panic(&ValueError{"reflect.Value.Len", iv.kind}) } -// Get returns the underlying uintptr value. -// Get returns uintptr, not unsafe.Pointer, so that -// programs that do not import "unsafe" cannot -// obtain a value of unsafe.Pointer type from "reflect". -func (v *UnsafePointerValue) Get() uintptr { return uintptr(*(*unsafe.Pointer)(v.addr)) } +// MapIndex returns the value associated with key in the map v. +// It panics if v's Kind is not Map. +// It returns the zero Value if key is not found in the map or if v represents a nil map. +// As in Go, the key's value must be assignable to the map's key type. +func (v Value) MapIndex(key Value) Value { + iv := v.internal() + iv.mustBe(Map) + typ := iv.typ.toType() -// Set sets v to the value x. -func (v *UnsafePointerValue) Set(x unsafe.Pointer) { - if !v.CanSet() { - panic(cannotSet) + // Do not require ikey to be exported, so that DeepEqual + // and other programs can use all the keys returned by + // MapKeys as arguments to MapIndex. If either the map + // or the key is unexported, though, the result will be + // considered unexported. + + ikey := key.internal() + ikey = convertForAssignment("reflect.Value.MapIndex", nil, typ.Key(), ikey) + if iv.word == 0 { + return Value{} } - *(*unsafe.Pointer)(v.addr) = x -} -// Set sets v to the value x. -func (v *UnsafePointerValue) SetValue(x Value) { - v.Set(unsafe.Pointer(x.(*UnsafePointerValue).Get())) + flag := (iv.flag | ikey.flag) & flagRO + elemType := typ.Elem() + elemWord, ok := mapaccess(*(*iword)(iv.addr), ikey.word) + if !ok { + return Value{} + } + return valueFromIword(flag, elemType, elemWord) } -func typesMustMatch(t1, t2 Type) { - if t1 != t2 { - panic("type mismatch: " + t1.String() + " != " + t2.String()) +// MapKeys returns a slice containing all the keys present in the map, +// in unspecified order. +// It panics if v's Kind is not Map. +// It returns an empty slice if v represents a nil map. +func (v Value) MapKeys() []Value { + iv := v.internal() + iv.mustBe(Map) + keyType := iv.typ.Key() + + flag := iv.flag & flagRO + m := *(*iword)(iv.addr) + mlen := int32(0) + if m != 0 { + mlen = maplen(m) + } + it := mapiterinit(m) + a := make([]Value, mlen) + var i int + for i = 0; i < len(a); i++ { + keyWord, ok := mapiterkey(it) + if !ok { + break + } + a[i] = valueFromIword(flag, keyType, keyWord) + mapiternext(it) } + return a[:i] } -/* - * array - */ +// Method returns a function value corresponding to v's i'th method. +// The arguments to a Call on the returned function should not include +// a receiver; the returned function will always use v as the receiver. +// Method panics if i is out of range. +func (v Value) Method(i int) Value { + iv := v.internal() + if iv.kind == Invalid { + panic(&ValueError{"reflect.Value.Method", Invalid}) + } + if i < 0 || i >= iv.typ.NumMethod() { + panic("reflect: Method index out of range") + } + return Value{v.Internal, i + 1} +} -// ArrayOrSliceValue is the common interface -// implemented by both ArrayValue and SliceValue. -type ArrayOrSliceValue interface { - Value - Len() int - Cap() int - Elem(i int) Value - addr() addr +// NumField returns the number of fields in the struct v. +// It panics if v's Kind is not Struct. +func (v Value) NumField() int { + iv := v.internal() + iv.mustBe(Struct) + return iv.typ.NumField() } -// grow grows the slice s so that it can hold extra more values, allocating -// more capacity if needed. It also returns the old and new slice lengths. -func grow(s *SliceValue, extra int) (*SliceValue, int, int) { - i0 := s.Len() - i1 := i0 + extra - if i1 < i0 { - panic("append: slice overflow") - } - m := s.Cap() - if i1 <= m { - return s.Slice(0, i1), i0, i1 +// OverflowComplex returns true if the complex128 x cannot be represented by v's type. +// It panics if v's Kind is not Complex64 or Complex128. +func (v Value) OverflowComplex(x complex128) bool { + iv := v.internal() + switch iv.kind { + case Complex64: + return overflowFloat32(real(x)) || overflowFloat32(imag(x)) + case Complex128: + return false } - if m == 0 { - m = extra - } else { - for m < i1 { - if i0 < 1024 { - m += m - } else { - m += m / 4 - } - } + panic(&ValueError{"reflect.Value.OverflowComplex", iv.kind}) +} + +// OverflowFloat returns true if the float64 x cannot be represented by v's type. +// It panics if v's Kind is not Float32 or Float64. +func (v Value) OverflowFloat(x float64) bool { + iv := v.internal() + switch iv.kind { + case Float32: + return overflowFloat32(x) + case Float64: + return false } - t := MakeSlice(s.Type().(*SliceType), i1, m) - Copy(t, s) - return t, i0, i1 + panic(&ValueError{"reflect.Value.OverflowFloat", iv.kind}) } -// Append appends the values x to a slice s and returns the resulting slice. -// Each x must have the same type as s' element type. -func Append(s *SliceValue, x ...Value) *SliceValue { - s, i0, i1 := grow(s, len(x)) - for i, j := i0, 0; i < i1; i, j = i+1, j+1 { - s.Elem(i).SetValue(x[j]) +func overflowFloat32(x float64) bool { + if x < 0 { + x = -x } - return s + return math.MaxFloat32 <= x && x <= math.MaxFloat64 +} + +// OverflowInt returns true if the int64 x cannot be represented by v's type. +// It panics if v's Kind is not Int, Int8, int16, Int32, or Int64. +func (v Value) OverflowInt(x int64) bool { + iv := v.internal() + switch iv.kind { + case Int, Int8, Int16, Int32, Int64: + bitSize := iv.typ.size * 8 + trunc := (x << (64 - bitSize)) >> (64 - bitSize) + return x != trunc + } + panic(&ValueError{"reflect.Value.OverflowInt", iv.kind}) +} + +// OverflowUint returns true if the uint64 x cannot be represented by v's type. +// It panics if v's Kind is not Uint, Uintptr, Uint8, Uint16, Uint32, or Uint64. +func (v Value) OverflowUint(x uint64) bool { + iv := v.internal() + switch iv.kind { + case Uint, Uintptr, Uint8, Uint16, Uint32, Uint64: + bitSize := iv.typ.size * 8 + trunc := (x << (64 - bitSize)) >> (64 - bitSize) + return x != trunc + } + panic(&ValueError{"reflect.Value.OverflowUint", iv.kind}) +} + +// Pointer returns v's value as a uintptr. +// It returns uintptr instead of unsafe.Pointer so that +// code using reflect cannot obtain unsafe.Pointers +// without importing the unsafe package explicitly. +// It panics if v's Kind is not Chan, Func, Map, Ptr, Slice, or UnsafePointer. +func (v Value) Pointer() uintptr { + iv := v.internal() + switch iv.kind { + case Ptr, UnsafePointer: + if iv.kind == Func && v.InternalMethod != 0 { + panic("reflect.Value.Pointer of method Value") + } + return uintptr(iv.word) + case Chan, Func, Map: + if iv.kind == Func && v.InternalMethod != 0 { + panic("reflect.Value.Pointer of method Value") + } + return *(*uintptr)(iv.addr) + case Slice: + return (*SliceHeader)(iv.addr).Data + } + panic(&ValueError{"reflect.Value.Pointer", iv.kind}) } -// AppendSlice appends a slice t to a slice s and returns the resulting slice. -// The slices s and t must have the same element type. -func AppendSlice(s, t *SliceValue) *SliceValue { - s, i0, i1 := grow(s, t.Len()) - Copy(s.Slice(i0, i1), t) - return s +// Recv receives and returns a value from the channel v. +// It panics if v's Kind is not Chan. +// The receive blocks until a value is ready. +// The boolean value ok is true if the value x corresponds to a send +// on the channel, false if it is a zero value received because the channel is closed. +func (v Value) Recv() (x Value, ok bool) { + iv := v.internal() + iv.mustBe(Chan) + iv.mustBeExported() + return iv.recv(false) } -// Copy copies the contents of src into dst until either -// dst has been filled or src has been exhausted. -// It returns the number of elements copied. -// The arrays dst and src must have the same element type. -func Copy(dst, src ArrayOrSliceValue) int { - // TODO: This will have to move into the runtime - // once the real gc goes in. - de := dst.Type().(ArrayOrSliceType).Elem() - se := src.Type().(ArrayOrSliceType).Elem() - typesMustMatch(de, se) - n := dst.Len() - if xn := src.Len(); n > xn { - n = xn +// internal recv, possibly non-blocking (nb) +func (iv internalValue) recv(nb bool) (val Value, ok bool) { + t := iv.typ.toType() + if t.ChanDir()&RecvDir == 0 { + panic("recv on send-only channel") } - memmove(dst.addr(), src.addr(), uintptr(n)*de.Size()) - return n + ch := *(*iword)(iv.addr) + if ch == 0 { + panic("recv on nil channel") + } + valWord, selected, ok := chanrecv(ch, nb) + if selected { + val = valueFromIword(0, t.Elem(), valWord) + } + return } -// An ArrayValue represents an array. -type ArrayValue struct { - value "array" +// Send sends x on the channel v. +// It panics if v's kind is not Chan or if x's type is not the same type as v's element type. +// As in Go, x's value must be assignable to the channel's element type. +func (v Value) Send(x Value) { + iv := v.internal() + iv.mustBe(Chan) + iv.mustBeExported() + iv.send(x, false) +} + +// internal send, possibly non-blocking +func (iv internalValue) send(x Value, nb bool) (selected bool) { + t := iv.typ.toType() + if t.ChanDir()&SendDir == 0 { + panic("send on recv-only channel") + } + ix := x.internal() + ix.mustBeExported() // do not let unexported x leak + ix = convertForAssignment("reflect.Value.Send", nil, t.Elem(), ix) + ch := *(*iword)(iv.addr) + if ch == 0 { + panic("send on nil channel") + } + return chansend(ch, ix.word, nb) } -// Len returns the length of the array. -func (v *ArrayValue) Len() int { return v.typ.(*ArrayType).Len() } +// Set assigns x to the value v. +// It panics if CanSet returns false. +// As in Go, x's value must be assignable to v's type. +func (v Value) Set(x Value) { + iv := v.internal() + ix := x.internal() -// Cap returns the capacity of the array (equal to Len()). -func (v *ArrayValue) Cap() int { return v.typ.(*ArrayType).Len() } + iv.mustBeAssignable() + ix.mustBeExported() // do not let unexported x leak -// addr returns the base address of the data in the array. -func (v *ArrayValue) addr() addr { return v.value.addr } + ix = convertForAssignment("reflect.Set", iv.addr, iv.typ, ix) -// Set assigns x to v. -// The new value x must have the same type as v. -func (v *ArrayValue) Set(x *ArrayValue) { - if !v.CanSet() { - panic(cannotSet) + n := ix.typ.size + if Kind(ix.typ.kind) == Ptr || Kind(ix.typ.kind) == UnsafePointer { + storeIword(iv.addr, ix.word, n) + } else { + memmove(iv.addr, ix.addr, n) } - typesMustMatch(v.typ, x.typ) - Copy(v, x) } -// Set sets v to the value x. -func (v *ArrayValue) SetValue(x Value) { v.Set(x.(*ArrayValue)) } +// SetBool sets v's underlying value. +// It panics if v's Kind is not Bool or if CanSet() is false. +func (v Value) SetBool(x bool) { + iv := v.internal() + iv.mustBeAssignable() + iv.mustBe(Bool) + *(*bool)(iv.addr) = x +} -// Elem returns the i'th element of v. -func (v *ArrayValue) Elem(i int) Value { - typ := v.typ.(*ArrayType).Elem() - n := v.Len() - if i < 0 || i >= n { - panic("array index out of bounds") +// SetComplex sets v's underlying value to x. +// It panics if v's Kind is not Complex64 or Complex128, or if CanSet() is false. +func (v Value) SetComplex(x complex128) { + iv := v.internal() + iv.mustBeAssignable() + switch iv.kind { + default: + panic(&ValueError{"reflect.Value.SetComplex", iv.kind}) + case Complex64: + *(*complex64)(iv.addr) = complex64(x) + case Complex128: + *(*complex128)(iv.addr) = x } - p := addr(uintptr(v.addr()) + uintptr(i)*typ.Size()) - return newValue(typ, p, v.flag) } -/* - * slice - */ - -// runtime representation of slice -type SliceHeader struct { - Data uintptr - Len int - Cap int +// SetFloat sets v's underlying value to x. +// It panics if v's Kind is not Float32 or Float64, or if CanSet() is false. +func (v Value) SetFloat(x float64) { + iv := v.internal() + iv.mustBeAssignable() + switch iv.kind { + default: + panic(&ValueError{"reflect.Value.SetFloat", iv.kind}) + case Float32: + *(*float32)(iv.addr) = float32(x) + case Float64: + *(*float64)(iv.addr) = x + } } -// A SliceValue represents a slice. -type SliceValue struct { - value "slice" +// SetInt sets v's underlying value to x. +// It panics if v's Kind is not Int, Int8, Int16, Int32, or Int64, or if CanSet() is false. +func (v Value) SetInt(x int64) { + iv := v.internal() + iv.mustBeAssignable() + switch iv.kind { + default: + panic(&ValueError{"reflect.Value.SetInt", iv.kind}) + case Int: + *(*int)(iv.addr) = int(x) + case Int8: + *(*int8)(iv.addr) = int8(x) + case Int16: + *(*int16)(iv.addr) = int16(x) + case Int32: + *(*int32)(iv.addr) = int32(x) + case Int64: + *(*int64)(iv.addr) = x + } } -func (v *SliceValue) slice() *SliceHeader { return (*SliceHeader)(v.value.addr) } - -// IsNil returns whether v is a nil slice. -func (v *SliceValue) IsNil() bool { return v.slice().Data == 0 } - -// Len returns the length of the slice. -func (v *SliceValue) Len() int { return int(v.slice().Len) } - -// Cap returns the capacity of the slice. -func (v *SliceValue) Cap() int { return int(v.slice().Cap) } - -// addr returns the base address of the data in the slice. -func (v *SliceValue) addr() addr { return addr(v.slice().Data) } - -// SetLen changes the length of v. -// The new length n must be between 0 and the capacity, inclusive. -func (v *SliceValue) SetLen(n int) { - s := v.slice() +// SetLen sets v's length to n. +// It panics if v's Kind is not Slice. +func (v Value) SetLen(n int) { + iv := v.internal() + iv.mustBeAssignable() + iv.mustBe(Slice) + s := (*SliceHeader)(iv.addr) if n < 0 || n > int(s.Cap) { panic("reflect: slice length out of range in SetLen") } s.Len = n } -// Set assigns x to v. -// The new value x must have the same type as v. -func (v *SliceValue) Set(x *SliceValue) { - if !v.CanSet() { - panic(cannotSet) - } - typesMustMatch(v.typ, x.typ) - *v.slice() = *x.slice() -} +// SetMapIndex sets the value associated with key in the map v to val. +// It panics if v's Kind is not Map. +// If val is the zero Value, SetMapIndex deletes the key from the map. +// As in Go, key's value must be assignable to the map's key type, +// and val's value must be assignable to the map's value type. +func (v Value) SetMapIndex(key, val Value) { + iv := v.internal() + ikey := key.internal() + ival := val.internal() -// Set sets v to the value x. -func (v *SliceValue) SetValue(x Value) { v.Set(x.(*SliceValue)) } + iv.mustBe(Map) + iv.mustBeExported() -// Get returns the uintptr address of the v.Cap()'th element. This gives -// the same result for all slices of the same array. -// It is mainly useful for printing. -func (v *SliceValue) Get() uintptr { - typ := v.typ.(*SliceType) - return uintptr(v.addr()) + uintptr(v.Cap())*typ.Elem().Size() -} + ikey.mustBeExported() + ikey = convertForAssignment("reflect.Value.SetMapIndex", nil, iv.typ.Key(), ikey) -// Slice returns a sub-slice of the slice v. -func (v *SliceValue) Slice(beg, end int) *SliceValue { - cap := v.Cap() - if beg < 0 || end < beg || end > cap { - panic("slice index out of bounds") + if ival.kind != Invalid { + ival.mustBeExported() + ival = convertForAssignment("reflect.Value.SetMapIndex", nil, iv.typ.Elem(), ival) } - typ := v.typ.(*SliceType) - s := new(SliceHeader) - s.Data = uintptr(v.addr()) + uintptr(beg)*typ.Elem().Size() - s.Len = end - beg - s.Cap = cap - beg - // Like the result of Addr, we treat Slice as an - // unaddressable temporary, so don't set canAddr. - flag := canSet - if v.flag&canStore != 0 { - flag |= canStore - } - return newValue(typ, addr(s), flag).(*SliceValue) + mapassign(*(*iword)(iv.addr), ikey.word, ival.word, ival.kind != Invalid) } -// Elem returns the i'th element of v. -func (v *SliceValue) Elem(i int) Value { - typ := v.typ.(*SliceType).Elem() - n := v.Len() - if i < 0 || i >= n { - panic("reflect: slice index out of range") - } - p := addr(uintptr(v.addr()) + uintptr(i)*typ.Size()) - flag := canAddr - if v.flag&canStore != 0 { - flag |= canSet | canStore +// SetUint sets v's underlying value to x. +// It panics if v's Kind is not Uint, Uintptr, Uint8, Uint16, Uint32, or Uint64, or if CanSet() is false. +func (v Value) SetUint(x uint64) { + iv := v.internal() + iv.mustBeAssignable() + switch iv.kind { + default: + panic(&ValueError{"reflect.Value.SetUint", iv.kind}) + case Uint: + *(*uint)(iv.addr) = uint(x) + case Uint8: + *(*uint8)(iv.addr) = uint8(x) + case Uint16: + *(*uint16)(iv.addr) = uint16(x) + case Uint32: + *(*uint32)(iv.addr) = uint32(x) + case Uint64: + *(*uint64)(iv.addr) = x + case Uintptr: + *(*uintptr)(iv.addr) = uintptr(x) } - return newValue(typ, p, flag) } -// MakeSlice creates a new zero-initialized slice value -// for the specified slice type, length, and capacity. -func MakeSlice(typ *SliceType, len, cap int) *SliceValue { - s := &SliceHeader{ - Data: uintptr(unsafe.NewArray(typ.Elem(), cap)), - Len: len, - Cap: cap, - } - return newValue(typ, addr(s), canAddr|canSet|canStore).(*SliceValue) +// SetPointer sets the unsafe.Pointer value v to x. +// It panics if v's Kind is not UnsafePointer. +func (v Value) SetPointer(x unsafe.Pointer) { + iv := v.internal() + iv.mustBeAssignable() + iv.mustBe(UnsafePointer) + *(*unsafe.Pointer)(iv.addr) = x } -/* - * chan - */ - -// A ChanValue represents a chan. -type ChanValue struct { - value "chan" +// SetString sets v's underlying value to x. +// It panics if v's Kind is not String or if CanSet() is false. +func (v Value) SetString(x string) { + iv := v.internal() + iv.mustBeAssignable() + iv.mustBe(String) + *(*string)(iv.addr) = x } -// IsNil returns whether v is a nil channel. -func (v *ChanValue) IsNil() bool { return *(*uintptr)(v.addr) == 0 } - -// Set assigns x to v. -// The new value x must have the same type as v. -func (v *ChanValue) Set(x *ChanValue) { - if !v.CanSet() { - panic(cannotSet) +// Slice returns a slice of v. +// It panics if v's Kind is not Array or Slice. +func (v Value) Slice(beg, end int) Value { + iv := v.internal() + if iv.kind != Array && iv.kind != Slice { + panic(&ValueError{"reflect.Value.Slice", iv.kind}) } - typesMustMatch(v.typ, x.typ) - *(*uintptr)(v.addr) = *(*uintptr)(x.addr) -} - -// Set sets v to the value x. -func (v *ChanValue) SetValue(x Value) { v.Set(x.(*ChanValue)) } - -// Get returns the uintptr value of v. -// It is mainly useful for printing. -func (v *ChanValue) Get() uintptr { return *(*uintptr)(v.addr) } - -// implemented in ../pkg/runtime/reflect.cgo -func makechan(typ *runtime.ChanType, size uint32) (ch *byte) -func chansend(ch, val *byte, selected *bool) -func chanrecv(ch, val *byte, selected *bool, ok *bool) -func chanclose(ch *byte) -func chanlen(ch *byte) int32 -func chancap(ch *byte) int32 - -// Close closes the channel. -func (v *ChanValue) Close() { - ch := *(**byte)(v.addr) - chanclose(ch) -} - -func (v *ChanValue) Len() int { - ch := *(**byte)(v.addr) - return int(chanlen(ch)) -} - -func (v *ChanValue) Cap() int { - ch := *(**byte)(v.addr) - return int(chancap(ch)) -} - -// internal send; non-blocking if selected != nil -func (v *ChanValue) send(x Value, selected *bool) { - t := v.Type().(*ChanType) - if t.Dir()&SendDir == 0 { - panic("send on recv-only channel") + cap := v.Cap() + if beg < 0 || end < beg || end > cap { + panic("reflect.Value.Slice: slice index out of bounds") + } + var typ Type + var base uintptr + switch iv.kind { + case Array: + if iv.flag&flagAddr == 0 { + panic("reflect.Value.Slice: slice of unaddressable array") + } + typ = toType((*arrayType)(unsafe.Pointer(iv.typ)).slice) + base = uintptr(iv.addr) + case Slice: + typ = iv.typ.toType() + base = (*SliceHeader)(iv.addr).Data } - typesMustMatch(t.Elem(), x.Type()) - ch := *(**byte)(v.addr) - chansend(ch, (*byte)(x.getAddr()), selected) + s := new(SliceHeader) + s.Data = base + uintptr(beg)*typ.Elem().Size() + s.Len = end - beg + s.Cap = cap - beg + return valueFromAddr(iv.flag&flagRO, typ, unsafe.Pointer(s)) } -// internal recv; non-blocking if selected != nil -func (v *ChanValue) recv(selected *bool) (Value, bool) { - t := v.Type().(*ChanType) - if t.Dir()&RecvDir == 0 { - panic("recv on send-only channel") +// String returns the string v's underlying value, as a string. +// String is a special case because of Go's String method convention. +// Unlike the other getters, it does not panic if v's Kind is not String. +// Instead, it returns a string of the form "" where T is v's type. +func (v Value) String() string { + iv := v.internal() + switch iv.kind { + case Invalid: + return "" + case String: + return *(*string)(iv.addr) } - ch := *(**byte)(v.addr) - x := MakeZero(t.Elem()) - var ok bool - chanrecv(ch, (*byte)(x.getAddr()), selected, &ok) - return x, ok + return "<" + iv.typ.String() + " Value>" } -// Send sends x on the channel v. -func (v *ChanValue) Send(x Value) { v.send(x, nil) } - -// Recv receives and returns a value from the channel v. -// The receive blocks until a value is ready. -// The boolean value ok is true if the value x corresponds to a send +// TryRecv attempts to receive a value from the channel v but will not block. +// It panics if v's Kind is not Chan. +// If the receive cannot finish without blocking, x is the zero Value. +// The boolean ok is true if the value x corresponds to a send // on the channel, false if it is a zero value received because the channel is closed. -func (v *ChanValue) Recv() (x Value, ok bool) { - return v.recv(nil) +func (v Value) TryRecv() (x Value, ok bool) { + iv := v.internal() + iv.mustBe(Chan) + iv.mustBeExported() + return iv.recv(true) } -// TrySend attempts to sends x on the channel v but will not block. +// TrySend attempts to send x on the channel v but will not block. +// It panics if v's Kind is not Chan. // It returns true if the value was sent, false otherwise. -func (v *ChanValue) TrySend(x Value) bool { - var selected bool - v.send(x, &selected) - return selected +// As in Go, x's value must be assignable to the channel's element type. +func (v Value) TrySend(x Value) bool { + iv := v.internal() + iv.mustBe(Chan) + iv.mustBeExported() + return iv.send(x, true) } -// TryRecv attempts to receive a value from the channel v but will not block. -// If the receive cannot finish without blocking, TryRecv instead returns x == nil. -// If the receive can finish without blocking, TryRecv returns x != nil. -// The boolean value ok is true if the value x corresponds to a send -// on the channel, false if it is a zero value received because the channel is closed. -func (v *ChanValue) TryRecv() (x Value, ok bool) { - var selected bool - x, ok = v.recv(&selected) - if !selected { - return nil, false +// Type returns v's type. +func (v Value) Type() Type { + t := v.internal().typ + if t == nil { + panic(&ValueError{"reflect.Value.Type", Invalid}) } - return x, ok + return t.toType() } -// MakeChan creates a new channel with the specified type and buffer size. -func MakeChan(typ *ChanType, buffer int) *ChanValue { - if buffer < 0 { - panic("MakeChan: negative buffer size") - } - if typ.Dir() != BothDir { - panic("MakeChan: unidirectional channel type") +// Uint returns v's underlying value, as a uint64. +// It panics if v's Kind is not Uint, Uintptr, Uint8, Uint16, Uint32, or Uint64. +func (v Value) Uint() uint64 { + iv := v.internal() + switch iv.kind { + case Uint: + return uint64(*(*uint)(iv.addr)) + case Uint8: + return uint64(*(*uint8)(iv.addr)) + case Uint16: + return uint64(*(*uint16)(iv.addr)) + case Uint32: + return uint64(*(*uint32)(iv.addr)) + case Uintptr: + return uint64(*(*uintptr)(iv.addr)) + case Uint64: + return *(*uint64)(iv.addr) } - v := MakeZero(typ).(*ChanValue) - *(**byte)(v.addr) = makechan((*runtime.ChanType)(unsafe.Pointer(typ)), uint32(buffer)) - return v + panic(&ValueError{"reflect.Value.Uint", iv.kind}) } -/* - * func - */ - -// A FuncValue represents a function value. -type FuncValue struct { - value "func" - first *value - isInterface bool -} - -// IsNil returns whether v is a nil function. -func (v *FuncValue) IsNil() bool { return *(*uintptr)(v.addr) == 0 } - -// Get returns the uintptr value of v. -// It is mainly useful for printing. -func (v *FuncValue) Get() uintptr { return *(*uintptr)(v.addr) } - -// Set assigns x to v. -// The new value x must have the same type as v. -func (v *FuncValue) Set(x *FuncValue) { - if !v.CanSet() { - panic(cannotSet) +// UnsafeAddr returns a pointer to v's data. +// It is for advanced clients that also import the "unsafe" package. +// It panics if v is not addressable. +func (v Value) UnsafeAddr() uintptr { + iv := v.internal() + if iv.kind == Invalid { + panic(&ValueError{"reflect.Value.UnsafeAddr", iv.kind}) + } + if iv.flag&flagAddr == 0 { + panic("reflect.Value.UnsafeAddr of unaddressable value") } - typesMustMatch(v.typ, x.typ) - *(*uintptr)(v.addr) = *(*uintptr)(x.addr) + return uintptr(iv.addr) } -// Set sets v to the value x. -func (v *FuncValue) SetValue(x Value) { v.Set(x.(*FuncValue)) } - -// Method returns a FuncValue corresponding to v's i'th method. -// The arguments to a Call on the returned FuncValue -// should not include a receiver; the FuncValue will use v -// as the receiver. -func (v *value) Method(i int) *FuncValue { - t := v.Type().uncommon() - if t == nil || i < 0 || i >= len(t.methods) { - return nil - } - p := &t.methods[i] - fn := p.tfn - fv := &FuncValue{value: value{runtimeToType(p.typ), addr(&fn), 0}, first: v, isInterface: false} - return fv +// StringHeader is the runtime representation of a string. +// It cannot be used safely or portably. +type StringHeader struct { + Data uintptr + Len int } -// implemented in ../pkg/runtime/*/asm.s -func call(typ *FuncType, fnaddr *byte, isInterface bool, params *addr, results *addr) +// SliceHeader is the runtime representation of a slice. +// It cannot be used safely or portably. +type SliceHeader struct { + Data uintptr + Len int + Cap int +} -// Interface returns the fv as an interface value. -// If fv is a method obtained by invoking Value.Method -// (as opposed to Type.Method), Interface cannot return an -// interface value, so it panics. -func (fv *FuncValue) Interface() interface{} { - if fv.first != nil { - panic("FuncValue: cannot create interface value for method with bound receiver") +func typesMustMatch(what string, t1, t2 Type) { + if t1 != t2 { + panic("reflect: " + what + ": " + t1.String() + " != " + t2.String()) } - return fv.value.Interface() } -// Call calls the function fv with input parameters in. -// It returns the function's output parameters as Values. -func (fv *FuncValue) Call(in []Value) []Value { - t := fv.Type().(*FuncType) - nin := len(in) - if fv.first != nil && !fv.isInterface { - nin++ - } - if nin != t.NumIn() { - panic("FuncValue: wrong argument count") +// grow grows the slice s so that it can hold extra more values, allocating +// more capacity if needed. It also returns the old and new slice lengths. +func grow(s Value, extra int) (Value, int, int) { + i0 := s.Len() + i1 := i0 + extra + if i1 < i0 { + panic("reflect.Append: slice overflow") } - if fv.first != nil && fv.isInterface { - nin++ + m := s.Cap() + if i1 <= m { + return s.Slice(0, i1), i0, i1 } - nout := t.NumOut() - - params := make([]addr, nin) - delta := 0 - off := 0 - if v := fv.first; v != nil { - // Hard-wired first argument. - if fv.isInterface { - // v is a single uninterpreted word - params[0] = v.getAddr() - } else { - // v is a real value - tv := v.Type() - - // This is a method, so we need to always pass - // a pointer. - vAddr := v.getAddr() - if ptv, ok := tv.(*PtrType); ok { - typesMustMatch(t.In(0), tv) + if m == 0 { + m = extra + } else { + for m < i1 { + if i0 < 1024 { + m += m } else { - p := addr(new(addr)) - *(*addr)(p) = vAddr - vAddr = p - typesMustMatch(t.In(0).(*PtrType).Elem(), tv) - } - - params[0] = vAddr - delta = 1 - } - off = 1 - } - for i, v := range in { - tv := v.Type() - tf := t.In(i + delta) - - // If this is really a method, and we are explicitly - // passing the object, then we need to pass the address - // of the object instead. Unfortunately, we don't - // have any way to know that this is a method, so we just - // check the type. FIXME: This is ugly. - vAddr := v.getAddr() - if i == 0 && tf != tv { - if ptf, ok := tf.(*PtrType); ok { - p := addr(new(addr)) - *(*addr)(p) = vAddr - vAddr = p - tf = ptf.Elem() + m += m / 4 } } - - typesMustMatch(tf, tv) - params[i+off] = vAddr - } - - ret := make([]Value, nout) - results := make([]addr, nout) - for i := 0; i < nout; i++ { - tv := t.Out(i) - v := MakeZero(tv) - results[i] = v.getAddr() - ret[i] = v } - - call(t, *(**byte)(fv.addr), fv.isInterface, ¶ms[0], &results[0]) - - return ret -} - -/* - * interface - */ - -// An InterfaceValue represents an interface value. -type InterfaceValue struct { - value "interface" -} - -// IsNil returns whether v is a nil interface value. -func (v *InterfaceValue) IsNil() bool { return v.Interface() == nil } - -// No single uinptr Get because v.Interface() is available. - -// Get returns the two words that represent an interface in the runtime. -// Those words are useful only when playing unsafe games. -func (v *InterfaceValue) Get() [2]uintptr { - return *(*[2]uintptr)(v.addr) -} - -// Elem returns the concrete value stored in the interface value v. -func (v *InterfaceValue) Elem() Value { return NewValue(v.Interface()) } - -// ../runtime/reflect.cgo -func setiface(typ *InterfaceType, x *interface{}, addr addr) - -// Set assigns x to v. -func (v *InterfaceValue) Set(x Value) { - var i interface{} - if x != nil { - i = x.Interface() - } - if !v.CanSet() { - panic(cannotSet) - } - // Two different representations; see comment in Get. - // Empty interface is easy. - t := v.typ.(*InterfaceType) - if t.NumMethod() == 0 { - *(*interface{})(v.addr) = i - return - } - - // Non-empty interface requires a runtime check. - setiface(t, &i, v.addr) + t := MakeSlice(s.Type(), i1, m) + Copy(t, s) + return t, i0, i1 } -// Set sets v to the value x. -func (v *InterfaceValue) SetValue(x Value) { v.Set(x) } - -// Method returns a FuncValue corresponding to v's i'th method. -// The arguments to a Call on the returned FuncValue -// should not include a receiver; the FuncValue will use v -// as the receiver. -func (v *InterfaceValue) Method(i int) *FuncValue { - t := v.Type().(*InterfaceType) - if t == nil || i < 0 || i >= len(t.methods) { - return nil +// Append appends the values x to a slice s and returns the resulting slice. +// As in Go, each x's value must be assignable to the slice's element type. +func Append(s Value, x ...Value) Value { + s.internal().mustBe(Slice) + s, i0, i1 := grow(s, len(x)) + for i, j := i0, 0; i < i1; i, j = i+1, j+1 { + s.Index(i).Set(x[j]) } - p := &t.methods[i] - - // Interface is two words: itable, data. - tab := *(**[10000]addr)(v.addr) - data := &value{Typeof((*byte)(nil)), addr(uintptr(v.addr) + ptrSize), 0} - - fn := tab[i+1] - fv := &FuncValue{value: value{runtimeToType(p.typ), addr(&fn), 0}, first: data, isInterface: true} - return fv + return s } -/* - * map - */ - -// A MapValue represents a map value. -type MapValue struct { - value "map" +// AppendSlice appends a slice t to a slice s and returns the resulting slice. +// The slices s and t must have the same element type. +func AppendSlice(s, t Value) Value { + s.internal().mustBe(Slice) + t.internal().mustBe(Slice) + typesMustMatch("reflect.AppendSlice", s.Type().Elem(), t.Type().Elem()) + s, i0, i1 := grow(s, t.Len()) + Copy(s.Slice(i0, i1), t) + return s } -// IsNil returns whether v is a nil map value. -func (v *MapValue) IsNil() bool { return *(*uintptr)(v.addr) == 0 } +// Copy copies the contents of src into dst until either +// dst has been filled or src has been exhausted. +// It returns the number of elements copied. +// Dst and src each must have kind Slice or Array, and +// dst and src must have the same element type. +func Copy(dst, src Value) int { + idst := dst.internal() + isrc := src.internal() -// Set assigns x to v. -// The new value x must have the same type as v. -func (v *MapValue) Set(x *MapValue) { - if !v.CanSet() { - panic(cannotSet) + if idst.kind != Array && idst.kind != Slice { + panic(&ValueError{"reflect.Copy", idst.kind}) } - if x == nil { - *(**uintptr)(v.addr) = nil - return + if idst.kind == Array { + idst.mustBeAssignable() } - typesMustMatch(v.typ, x.typ) - *(*uintptr)(v.addr) = *(*uintptr)(x.addr) -} - -// Set sets v to the value x. -func (v *MapValue) SetValue(x Value) { - if x == nil { - v.Set(nil) - return + idst.mustBeExported() + if isrc.kind != Array && isrc.kind != Slice { + panic(&ValueError{"reflect.Copy", isrc.kind}) } - v.Set(x.(*MapValue)) -} + isrc.mustBeExported() -// Get returns the uintptr value of v. -// It is mainly useful for printing. -func (v *MapValue) Get() uintptr { return *(*uintptr)(v.addr) } + de := idst.typ.Elem() + se := isrc.typ.Elem() + typesMustMatch("reflect.Copy", de, se) -// implemented in ../pkg/runtime/reflect.cgo -func mapaccess(m, key, val *byte) bool -func mapassign(m, key, val *byte) -func maplen(m *byte) int32 -func mapiterinit(m *byte) *byte -func mapiternext(it *byte) -func mapiterkey(it *byte, key *byte) bool -func makemap(t *runtime.MapType) *byte - -// Elem returns the value associated with key in the map v. -// It returns nil if key is not found in the map. -func (v *MapValue) Elem(key Value) Value { - t := v.Type().(*MapType) - typesMustMatch(t.Key(), key.Type()) - m := *(**byte)(v.addr) - if m == nil { - return nil - } - newval := MakeZero(t.Elem()) - if !mapaccess(m, (*byte)(key.getAddr()), (*byte)(newval.getAddr())) { - return nil - } - return newval -} - -// SetElem sets the value associated with key in the map v to val. -// If val is nil, Put deletes the key from map. -func (v *MapValue) SetElem(key, val Value) { - t := v.Type().(*MapType) - typesMustMatch(t.Key(), key.Type()) - var vaddr *byte - if val != nil { - typesMustMatch(t.Elem(), val.Type()) - vaddr = (*byte)(val.getAddr()) + n := dst.Len() + if sn := src.Len(); n > sn { + n = sn } - m := *(**byte)(v.addr) - mapassign(m, (*byte)(key.getAddr()), vaddr) -} -// Len returns the number of keys in the map v. -func (v *MapValue) Len() int { - m := *(**byte)(v.addr) - if m == nil { - return 0 + // If sk is an in-line array, cannot take its address. + // Instead, copy element by element. + if isrc.addr == nil { + for i := 0; i < n; i++ { + dst.Index(i).Set(src.Index(i)) + } + return n } - return int(maplen(m)) -} -// Keys returns a slice containing all the keys present in the map, -// in unspecified order. -func (v *MapValue) Keys() []Value { - tk := v.Type().(*MapType).Key() - m := *(**byte)(v.addr) - mlen := int32(0) - if m != nil { - mlen = maplen(m) + // Copy via memmove. + var da, sa unsafe.Pointer + if idst.kind == Array { + da = idst.addr + } else { + da = unsafe.Pointer((*SliceHeader)(idst.addr).Data) } - it := mapiterinit(m) - a := make([]Value, mlen) - var i int - for i = 0; i < len(a); i++ { - k := MakeZero(tk) - if !mapiterkey(it, (*byte)(k.getAddr())) { - break - } - a[i] = k - mapiternext(it) + if isrc.kind == Array { + sa = isrc.addr + } else { + sa = unsafe.Pointer((*SliceHeader)(isrc.addr).Data) } - return a[0:i] -} - -// MakeMap creates a new map of the specified type. -func MakeMap(typ *MapType) *MapValue { - v := MakeZero(typ).(*MapValue) - *(**byte)(v.addr) = makemap((*runtime.MapType)(unsafe.Pointer(typ))) - return v + memmove(da, sa, uintptr(n)*de.Size()) + return n } /* - * ptr + * constructors */ -// A PtrValue represents a pointer. -type PtrValue struct { - value "ptr" -} - -// IsNil returns whether v is a nil pointer. -func (v *PtrValue) IsNil() bool { return *(*uintptr)(v.addr) == 0 } - -// Get returns the uintptr value of v. -// It is mainly useful for printing. -func (v *PtrValue) Get() uintptr { return *(*uintptr)(v.addr) } - -// Set assigns x to v. -// The new value x must have the same type as v, and x.Elem().CanSet() must be true. -func (v *PtrValue) Set(x *PtrValue) { - if x == nil { - *(**uintptr)(v.addr) = nil - return - } - if !v.CanSet() { - panic(cannotSet) +// MakeSlice creates a new zero-initialized slice value +// for the specified slice type, length, and capacity. +func MakeSlice(typ Type, len, cap int) Value { + if typ.Kind() != Slice { + panic("reflect: MakeSlice of non-slice type") } - if x.flag&canStore == 0 { - panic("cannot copy pointer obtained from unexported struct field") + s := &SliceHeader{ + Data: uintptr(unsafe.NewArray(typ.Elem(), cap)), + Len: len, + Cap: cap, } - typesMustMatch(v.typ, x.typ) - // TODO: This will have to move into the runtime - // once the new gc goes in - *(*uintptr)(v.addr) = *(*uintptr)(x.addr) + return valueFromAddr(0, typ, unsafe.Pointer(s)) } -// Set sets v to the value x. -func (v *PtrValue) SetValue(x Value) { - if x == nil { - v.Set(nil) - return +// MakeChan creates a new channel with the specified type and buffer size. +func MakeChan(typ Type, buffer int) Value { + if typ.Kind() != Chan { + panic("reflect: MakeChan of non-chan type") } - v.Set(x.(*PtrValue)) -} - -// PointTo changes v to point to x. -// If x is a nil Value, PointTo sets v to nil. -func (v *PtrValue) PointTo(x Value) { - if x == nil { - *(**uintptr)(v.addr) = nil - return + if buffer < 0 { + panic("MakeChan: negative buffer size") } - if !x.CanSet() { - panic("cannot set x; cannot point to x") + if typ.ChanDir() != BothDir { + panic("MakeChan: unidirectional channel type") } - typesMustMatch(v.typ.(*PtrType).Elem(), x.Type()) - // TODO: This will have to move into the runtime - // once the new gc goes in. - *(*uintptr)(v.addr) = x.UnsafeAddr() + ch := makechan(typ.runtimeType(), uint32(buffer)) + return valueFromIword(0, typ, ch) } -// Elem returns the value that v points to. -// If v is a nil pointer, Elem returns a nil Value. -func (v *PtrValue) Elem() Value { - if v.IsNil() { - return nil - } - flag := canAddr - if v.flag&canStore != 0 { - flag |= canSet | canStore +// MakeMap creates a new map of the specified type. +func MakeMap(typ Type) Value { + if typ.Kind() != Map { + panic("reflect: MakeMap of non-map type") } - return newValue(v.typ.(*PtrType).Elem(), *(*addr)(v.addr), flag) + m := makemap(typ.runtimeType()) + return valueFromIword(0, typ, m) } // Indirect returns the value that v points to. // If v is a nil pointer, Indirect returns a nil Value. // If v is not a pointer, Indirect returns v. func Indirect(v Value) Value { - if pv, ok := v.(*PtrValue); ok { - return pv.Elem() + if v.Kind() != Ptr { + return v } - return v + return v.Elem() } -/* - * struct - */ - -// A StructValue represents a struct value. -type StructValue struct { - value "struct" -} - -// Set assigns x to v. -// The new value x must have the same type as v. -func (v *StructValue) Set(x *StructValue) { - // TODO: This will have to move into the runtime - // once the gc goes in. - if !v.CanSet() { - panic(cannotSet) +// ValueOf returns a new Value initialized to the concrete value +// stored in the interface i. ValueOf(nil) returns the zero Value. +func ValueOf(i interface{}) Value { + if i == nil { + return Value{} } - typesMustMatch(v.typ, x.typ) - memmove(v.addr, x.addr, v.typ.Size()) + // For an interface value with the noAddr bit set, + // the representation is identical to an empty interface. + eface := *(*emptyInterface)(unsafe.Pointer(&i)) + return packValue(0, eface.typ, eface.word) } -// Set sets v to the value x. -func (v *StructValue) SetValue(x Value) { v.Set(x.(*StructValue)) } - -// Field returns the i'th field of the struct. -func (v *StructValue) Field(i int) Value { - t := v.typ.(*StructType) - if i < 0 || i >= t.NumField() { - return nil +// Zero returns a Value representing a zero value for the specified type. +// The result is different from the zero value of the Value struct, +// which represents no value at all. +// For example, Zero(TypeOf(42)) returns a Value with Kind Int and value 0. +func Zero(typ Type) Value { + if typ == nil { + panic("reflect: Zero(nil)") } - f := t.Field(i) - flag := v.flag - if f.PkgPath != "" { - // unexported field - flag &^= canSet | canStore + if typ.Kind() == Ptr || typ.Kind() == UnsafePointer { + return valueFromIword(0, typ, 0) } - return newValue(f.Type, addr(uintptr(v.addr)+f.Offset), flag) + return valueFromAddr(0, typ, unsafe.New(typ)) } -// FieldByIndex returns the nested field corresponding to index. -func (t *StructValue) FieldByIndex(index []int) (v Value) { - v = t - for i, x := range index { - if i > 0 { - if p, ok := v.(*PtrValue); ok { - v = p.Elem() - } - if s, ok := v.(*StructValue); ok { - t = s - } else { - v = nil - return - } - } - v = t.Field(x) +// New returns a Value representing a pointer to a new zero value +// for the specified type. That is, the returned Value's Type is PtrTo(t). +func New(typ Type) Value { + if typ == nil { + panic("reflect: New(nil)") } - return + ptr := unsafe.New(typ) + return valueFromIword(0, PtrTo(typ), iword(uintptr(ptr))) } -// FieldByName returns the struct field with the given name. -// The result is nil if no field was found. -func (t *StructValue) FieldByName(name string) Value { - if f, ok := t.Type().(*StructType).FieldByName(name); ok { - return t.FieldByIndex(f.Index) +// convertForAssignment +func convertForAssignment(what string, addr unsafe.Pointer, dst Type, iv internalValue) internalValue { + if iv.method { + panic(what + ": cannot assign method value to type " + dst.String()) } - return nil -} -// FieldByNameFunc returns the struct field with a name that satisfies the -// match function. -// The result is nil if no field was found. -func (t *StructValue) FieldByNameFunc(match func(string) bool) Value { - if f, ok := t.Type().(*StructType).FieldByNameFunc(match); ok { - return t.FieldByIndex(f.Index) + dst1 := dst.(*commonType) + if directlyAssignable(dst1, iv.typ) { + // Overwrite type so that they match. + // Same memory layout, so no harm done. + iv.typ = dst1 + return iv } - return nil + if implements(dst1, iv.typ) { + if addr == nil { + addr = unsafe.Pointer(new(interface{})) + } + x := iv.Interface() + if dst.NumMethod() == 0 { + *(*interface{})(addr) = x + } else { + ifaceE2I(dst1.runtimeType(), x, addr) + } + iv.addr = addr + iv.word = iword(uintptr(addr)) + iv.typ = dst1 + return iv + } + + // Failed. + panic(what + ": value of type " + iv.typ.String() + " is not assignable to type " + dst.String()) } -// NumField returns the number of fields in the struct. -func (v *StructValue) NumField() int { return v.typ.(*StructType).NumField() } +// implemented in ../pkg/runtime +func chancap(ch iword) int32 +func chanclose(ch iword) +func chanlen(ch iword) int32 +func chanrecv(ch iword, nb bool) (val iword, selected, received bool) +func chansend(ch iword, val iword, nb bool) bool -/* - * constructors - */ +func makechan(typ *runtime.Type, size uint32) (ch iword) +func makemap(t *runtime.Type) iword +func mapaccess(m iword, key iword) (val iword, ok bool) +func mapassign(m iword, key, val iword, ok bool) +func mapiterinit(m iword) *byte +func mapiterkey(it *byte) (key iword, ok bool) +func mapiternext(it *byte) +func maplen(m iword) int32 -// NewValue returns a new Value initialized to the concrete value -// stored in the interface i. NewValue(nil) returns nil. -func NewValue(i interface{}) Value { - if i == nil { - return nil - } - t, a := unsafe.Reflect(i) - return newValue(canonicalize(toType(t)), addr(a), canSet|canAddr|canStore) -} - -func newValue(typ Type, addr addr, flag uint32) Value { - v := value{typ, addr, flag} - switch typ.(type) { - case *ArrayType: - return &ArrayValue{v} - case *BoolType: - return &BoolValue{v} - case *ChanType: - return &ChanValue{v} - case *FloatType: - return &FloatValue{v} - case *FuncType: - return &FuncValue{value: v} - case *ComplexType: - return &ComplexValue{v} - case *IntType: - return &IntValue{v} - case *InterfaceType: - return &InterfaceValue{v} - case *MapType: - return &MapValue{v} - case *PtrType: - return &PtrValue{v} - case *SliceType: - return &SliceValue{v} - case *StringType: - return &StringValue{v} - case *StructType: - return &StructValue{v} - case *UintType: - return &UintValue{v} - case *UnsafePointerType: - return &UnsafePointerValue{v} - } - panic("newValue" + typ.String()) -} - -// MakeZero returns a zero Value for the specified Type. -func MakeZero(typ Type) Value { - if typ == nil { - return nil - } - return newValue(typ, addr(unsafe.New(typ)), canSet|canAddr|canStore) -} +func call(typ *commonType, fnaddr unsafe.Pointer, isInterface bool, isMethod bool, params *unsafe.Pointer, results *unsafe.Pointer) +func ifaceE2I(t *runtime.Type, src interface{}, dst unsafe.Pointer) diff --git a/libgo/go/rpc/server.go b/libgo/go/rpc/server.go index 1cc8c3173a8..acadeec37f0 100644 --- a/libgo/go/rpc/server.go +++ b/libgo/go/rpc/server.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* - The rpc package provides access to the exported methods of an object across a + Package rpc provides access to the exported methods of an object across a network or other I/O connection. A server registers an object, making it visible as a service with the name of the type of the object. After registration, exported methods of the object will be accessible remotely. A server may register multiple @@ -13,8 +13,11 @@ Only methods that satisfy these criteria will be made available for remote access; other methods will be ignored: - - the method receiver and name are exported, that is, begin with an upper case letter. - - the method has two arguments, both pointers to exported types. + - the method name is exported, that is, begins with an upper case letter. + - the method receiver is exported or local (defined in the package + registering the service). + - the method has two arguments, both exported or local types. + - the method's second argument is a pointer. - the method has return type os.Error. The method's first argument represents the arguments provided by the caller; the @@ -133,13 +136,13 @@ const ( // Precompute the reflect type for os.Error. Can't use os.Error directly // because Typeof takes an empty interface value. This is annoying. var unusedError *os.Error -var typeOfOsError = reflect.Typeof(unusedError).(*reflect.PtrType).Elem() +var typeOfOsError = reflect.TypeOf(unusedError).Elem() type methodType struct { sync.Mutex // protects counters method reflect.Method - ArgType *reflect.PtrType - ReplyType *reflect.PtrType + ArgType reflect.Type + ReplyType reflect.Type numCalls uint } @@ -193,6 +196,14 @@ func isExported(name string) bool { return unicode.IsUpper(rune) } +// Is this type exported or local to this package? +func isExportedOrLocalType(t reflect.Type) bool { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t.PkgPath() == "" || isExported(t.Name()) +} + // Register publishes in the server the set of methods of the // receiver value that satisfy the following conditions: // - exported method @@ -219,8 +230,8 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) os.E server.serviceMap = make(map[string]*service) } s := new(service) - s.typ = reflect.Typeof(rcvr) - s.rcvr = reflect.NewValue(rcvr) + s.typ = reflect.TypeOf(rcvr) + s.rcvr = reflect.ValueOf(rcvr) sname := reflect.Indirect(s.rcvr).Type().Name() if useName { sname = name @@ -252,22 +263,20 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) os.E log.Println("method", mname, "has wrong number of ins:", mtype.NumIn()) continue } - argType, ok := mtype.In(1).(*reflect.PtrType) - if !ok { - log.Println(mname, "arg type not a pointer:", mtype.In(1)) - continue - } - replyType, ok := mtype.In(2).(*reflect.PtrType) - if !ok { - log.Println(mname, "reply type not a pointer:", mtype.In(2)) + // First arg need not be a pointer. + argType := mtype.In(1) + if !isExportedOrLocalType(argType) { + log.Println(mname, "argument type not exported or local:", argType) continue } - if argType.Elem().PkgPath() != "" && !isExported(argType.Elem().Name()) { - log.Println(mname, "argument type not exported:", argType) + // Second arg must be a pointer. + replyType := mtype.In(2) + if replyType.Kind() != reflect.Ptr { + log.Println("method", mname, "reply type not a pointer:", replyType) continue } - if replyType.Elem().PkgPath() != "" && !isExported(replyType.Elem().Name()) { - log.Println(mname, "reply type not exported:", replyType) + if !isExportedOrLocalType(replyType) { + log.Println("method", mname, "reply type not exported or local:", replyType) continue } // Method needs one out: os.Error. @@ -296,12 +305,6 @@ type InvalidRequest struct{} var invalidRequest = InvalidRequest{} -func _new(t *reflect.PtrType) *reflect.PtrValue { - v := reflect.MakeZero(t).(*reflect.PtrValue) - v.PointTo(reflect.MakeZero(t.Elem())) - return v -} - func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) { resp := server.getResponse() // Encode the response header @@ -410,8 +413,16 @@ func (server *Server) ServeCodec(codec ServerCodec) { } // Decode the argument value. - argv := _new(mtype.ArgType) - replyv := _new(mtype.ReplyType) + var argv reflect.Value + argIsValue := false // if true, need to indirect before calling. + if mtype.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(mtype.ArgType.Elem()) + } else { + argv = reflect.New(mtype.ArgType) + argIsValue = true + } + // argv guaranteed to be a pointer now. + replyv := reflect.New(mtype.ReplyType.Elem()) err = codec.ReadRequestBody(argv.Interface()) if err != nil { if err == os.EOF || err == io.ErrUnexpectedEOF { @@ -423,6 +434,9 @@ func (server *Server) ServeCodec(codec ServerCodec) { server.sendResponse(sending, req, replyv.Interface(), codec, err.String()) continue } + if argIsValue { + argv = argv.Elem() + } go service.call(server, sending, mtype, req, argv, replyv, codec) } codec.Close() diff --git a/libgo/go/rpc/server_test.go b/libgo/go/rpc/server_test.go index d4041ae70ce..cfff0c9ad50 100644 --- a/libgo/go/rpc/server_test.go +++ b/libgo/go/rpc/server_test.go @@ -38,7 +38,9 @@ type Reply struct { type Arith int -func (t *Arith) Add(args *Args, reply *Reply) os.Error { +// Some of Arith's methods have value args, some have pointer args. That's deliberate. + +func (t *Arith) Add(args Args, reply *Reply) os.Error { reply.C = args.A + args.B return nil } @@ -48,7 +50,7 @@ func (t *Arith) Mul(args *Args, reply *Reply) os.Error { return nil } -func (t *Arith) Div(args *Args, reply *Reply) os.Error { +func (t *Arith) Div(args Args, reply *Reply) os.Error { if args.B == 0 { return os.ErrorString("divide by zero") } @@ -61,8 +63,8 @@ func (t *Arith) String(args *Args, reply *string) os.Error { return nil } -func (t *Arith) Scan(args *string, reply *Reply) (err os.Error) { - _, err = fmt.Sscan(*args, &reply.C) +func (t *Arith) Scan(args string, reply *Reply) (err os.Error) { + _, err = fmt.Sscan(args, &reply.C) return } @@ -262,16 +264,11 @@ func testHTTPRPC(t *testing.T, path string) { } } -type ArgNotPointer int type ReplyNotPointer int type ArgNotPublic int type ReplyNotPublic int type local struct{} -func (t *ArgNotPointer) ArgNotPointer(args Args, reply *Reply) os.Error { - return nil -} - func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) os.Error { return nil } @@ -286,11 +283,7 @@ func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) os.Error { // Check that registration handles lots of bad methods and a type with no suitable methods. func TestRegistrationError(t *testing.T) { - err := Register(new(ArgNotPointer)) - if err == nil { - t.Errorf("expected error registering ArgNotPointer") - } - err = Register(new(ReplyNotPointer)) + err := Register(new(ReplyNotPointer)) if err == nil { t.Errorf("expected error registering ReplyNotPointer") } @@ -351,18 +344,26 @@ func testSendDeadlock(client *Client) { client.Call("Arith.Add", args, reply) } -func TestCountMallocs(t *testing.T) { +func dialDirect() (*Client, os.Error) { + return Dial("tcp", serverAddr) +} + +func dialHTTP() (*Client, os.Error) { + return DialHTTP("tcp", httpServerAddr) +} + +func countMallocs(dial func() (*Client, os.Error), t *testing.T) uint64 { once.Do(startServer) - client, err := Dial("tcp", serverAddr) + client, err := dial() if err != nil { - t.Error("error dialing", err) + t.Fatal("error dialing", err) } args := &Args{7, 8} reply := new(Reply) mallocs := 0 - runtime.MemStats.Mallocs const count = 100 for i := 0; i < count; i++ { - err = client.Call("Arith.Add", args, reply) + err := client.Call("Arith.Add", args, reply) if err != nil { t.Errorf("Add: expected no error but got string %q", err.String()) } @@ -371,13 +372,21 @@ func TestCountMallocs(t *testing.T) { } } mallocs += runtime.MemStats.Mallocs - fmt.Printf("mallocs per rpc round trip: %d\n", mallocs/count) + return mallocs / count } -func BenchmarkEndToEnd(b *testing.B) { +func TestCountMallocs(t *testing.T) { + fmt.Printf("mallocs per rpc round trip: %d\n", countMallocs(dialDirect, t)) +} + +func TestCountMallocsOverHTTP(t *testing.T) { + fmt.Printf("mallocs per HTTP rpc round trip: %d\n", countMallocs(dialHTTP, t)) +} + +func benchmarkEndToEnd(dial func() (*Client, os.Error), b *testing.B) { b.StopTimer() once.Do(startServer) - client, err := Dial("tcp", serverAddr) + client, err := dial() if err != nil { fmt.Println("error dialing", err) return @@ -399,3 +408,11 @@ func BenchmarkEndToEnd(b *testing.B) { } } } + +func BenchmarkEndToEnd(b *testing.B) { + benchmarkEndToEnd(dialDirect, b) +} + +func BenchmarkEndToEndHTTP(b *testing.B) { + benchmarkEndToEnd(dialHTTP, b) +} diff --git a/libgo/go/runtime/debug/stack.go b/libgo/go/runtime/debug/stack.go index e7d56ac233d..e5fae632b13 100644 --- a/libgo/go/runtime/debug/stack.go +++ b/libgo/go/runtime/debug/stack.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The debug package contains facilities for programs to debug themselves -// while they are running. +// Package debug contains facilities for programs to debug themselves while +// they are running. package debug import ( diff --git a/libgo/go/runtime/extern.go b/libgo/go/runtime/extern.go index c6e664abbbb..9da3423c618 100644 --- a/libgo/go/runtime/extern.go +++ b/libgo/go/runtime/extern.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* - The runtime package contains operations that interact with Go's runtime system, + Package runtime contains operations that interact with Go's runtime system, such as functions to control goroutines. It also includes the low-level type information used by the reflect package; see reflect's documentation for the programmable interface to the run-time type system. diff --git a/libgo/go/runtime/proc_test.go b/libgo/go/runtime/proc_test.go new file mode 100644 index 00000000000..a15b2d80a4e --- /dev/null +++ b/libgo/go/runtime/proc_test.go @@ -0,0 +1,43 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package runtime_test + +import ( + "runtime" + "testing" +) + +var stop = make(chan bool, 1) + +func perpetuumMobile() { + select { + case <-stop: + default: + go perpetuumMobile() + } +} + +func TestStopTheWorldDeadlock(t *testing.T) { + if testing.Short() { + t.Logf("skipping during short test") + return + } + runtime.GOMAXPROCS(3) + compl := make(chan int, 1) + go func() { + for i := 0; i != 1000; i += 1 { + runtime.GC() + } + compl <- 0 + }() + go func() { + for i := 0; i != 1000; i += 1 { + runtime.GOMAXPROCS(3) + } + }() + go perpetuumMobile() + <-compl + stop <- true +} diff --git a/libgo/go/runtime/type.go b/libgo/go/runtime/type.go index f5f3ef1baad..b59f2e4c383 100644 --- a/libgo/go/runtime/type.go +++ b/libgo/go/runtime/type.go @@ -117,8 +117,9 @@ type UnsafePointerType commonType // ArrayType represents a fixed array type. type ArrayType struct { commonType - elem *Type // array element type - len uintptr + elem *Type // array element type + slice *Type // slice type + len uintptr } // SliceType represents a slice type. diff --git a/libgo/go/scanner/scanner.go b/libgo/go/scanner/scanner.go index 560e595b45a..e79d392f70c 100644 --- a/libgo/go/scanner/scanner.go +++ b/libgo/go/scanner/scanner.go @@ -2,10 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// A scanner and tokenizer for UTF-8-encoded text. Takes an io.Reader -// providing the source, which then can be tokenized through repeated calls -// to the Scan function. For compatibility with existing tools, the NUL -// character is not allowed (implementation restriction). +// Package scanner provides a scanner and tokenizer for UTF-8-encoded text. +// It takes an io.Reader providing the source, which then can be tokenized +// through repeated calls to the Scan function. For compatibility with +// existing tools, the NUL character is not allowed (implementation +// restriction). // // By default, a Scanner skips white space and Go comments and recognizes all // literals as defined by the Go language specification. It may be @@ -115,7 +116,7 @@ func TokenString(tok int) string { if s, found := tokenString[tok]; found { return s } - return fmt.Sprintf("U+%04X", tok) + return fmt.Sprintf("%q", string(tok)) } diff --git a/libgo/go/sort/sort.go b/libgo/go/sort/sort.go index c7945d21b61..30b1819af2d 100644 --- a/libgo/go/sort/sort.go +++ b/libgo/go/sort/sort.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The sort package provides primitives for sorting arrays -// and user-defined collections. +// Package sort provides primitives for sorting arrays and user-defined +// collections. package sort // A type, typically a collection, that satisfies sort.Interface can be diff --git a/libgo/go/strconv/atof.go b/libgo/go/strconv/atof.go index 72f162c5134..a91e8bfa4aa 100644 --- a/libgo/go/strconv/atof.go +++ b/libgo/go/strconv/atof.go @@ -2,16 +2,16 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// Package strconv implements conversions to and from string representations +// of basic data types. +package strconv + // decimal to binary floating point conversion. // Algorithm: // 1) Store input in multiprecision decimal. // 2) Multiply/divide decimal by powers of two until in range [0.5, 1) // 3) Multiply by 2^precision and round to get mantissa. -// The strconv package implements conversions to and from -// string representations of basic data types. -package strconv - import ( "math" "os" diff --git a/libgo/go/strings/strings.go b/libgo/go/strings/strings.go index 93c7c464738..bfd057180d7 100644 --- a/libgo/go/strings/strings.go +++ b/libgo/go/strings/strings.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// A package of simple functions to manipulate strings. +// Package strings implements simple functions to manipulate strings. package strings import ( diff --git a/libgo/go/sync/mutex.go b/libgo/go/sync/mutex.go index da565d38def..13f03cad394 100644 --- a/libgo/go/sync/mutex.go +++ b/libgo/go/sync/mutex.go @@ -2,11 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The sync package provides basic synchronization primitives -// such as mutual exclusion locks. Other than the Once and -// WaitGroup types, most are intended for use by low-level -// library routines. Higher-level synchronization is better -// done via channels and communication. +// Package sync provides basic synchronization primitives such as mutual +// exclusion locks. Other than the Once and WaitGroup types, most are intended +// for use by low-level library routines. Higher-level synchronization is +// better done via channels and communication. package sync import ( diff --git a/libgo/go/syslog/syslog.go b/libgo/go/syslog/syslog.go index 4ada113f1d7..69333721276 100644 --- a/libgo/go/syslog/syslog.go +++ b/libgo/go/syslog/syslog.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The syslog package provides a simple interface to -// the system log service. It can send messages to the -// syslog daemon using UNIX domain sockets, UDP, or +// Package syslog provides a simple interface to the system log service. It +// can send messages to the syslog daemon using UNIX domain sockets, UDP, or // TCP connections. package syslog diff --git a/libgo/go/syslog/syslog_test.go b/libgo/go/syslog/syslog_test.go index 2958bcb1f8e..5c0b3e0c4e2 100644 --- a/libgo/go/syslog/syslog_test.go +++ b/libgo/go/syslog/syslog_test.go @@ -35,7 +35,19 @@ func startServer(done chan<- string) { go runSyslog(c, done) } +func skipNetTest(t *testing.T) bool { + if testing.Short() { + // Depends on syslog daemon running, and sometimes it's not. + t.Logf("skipping syslog test during -short") + return true + } + return false +} + func TestNew(t *testing.T) { + if skipNetTest(t) { + return + } s, err := New(LOG_INFO, "") if err != nil { t.Fatalf("New() failed: %s", err) @@ -45,6 +57,9 @@ func TestNew(t *testing.T) { } func TestNewLogger(t *testing.T) { + if skipNetTest(t) { + return + } f := NewLogger(LOG_INFO, 0) if f == nil { t.Error("NewLogger() failed") @@ -52,6 +67,9 @@ func TestNewLogger(t *testing.T) { } func TestDial(t *testing.T) { + if skipNetTest(t) { + return + } l, err := Dial("", "", LOG_ERR, "syslog_test") if err != nil { t.Fatalf("Dial() failed: %s", err) diff --git a/libgo/go/tabwriter/tabwriter.go b/libgo/go/tabwriter/tabwriter.go index 848703e8ca0..d91a07db242 100644 --- a/libgo/go/tabwriter/tabwriter.go +++ b/libgo/go/tabwriter/tabwriter.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The tabwriter package implements a write filter (tabwriter.Writer) -// that translates tabbed columns in input into properly aligned text. +// Package tabwriter implements a write filter (tabwriter.Writer) that +// translates tabbed columns in input into properly aligned text. // // The package is using the Elastic Tabstops algorithm described at // http://nickgravgaard.com/elastictabstops/index.html. diff --git a/libgo/go/template/template.go b/libgo/go/template/template.go index ba06de4e3ab..25320785223 100644 --- a/libgo/go/template/template.go +++ b/libgo/go/template/template.go @@ -3,8 +3,8 @@ // license that can be found in the LICENSE file. /* - Data-driven templates for generating textual output such as - HTML. + Package template implements data-driven templates for generating textual + output such as HTML. Templates are executed by applying them to a data structure. Annotations in the template refer to elements of the data @@ -621,7 +621,7 @@ func (t *Template) parse() { // Evaluate interfaces and pointers looking for a value that can look up the name, via a // struct field, method, or map key, and return the result of the lookup. func (t *Template) lookup(st *state, v reflect.Value, name string) reflect.Value { - for v != nil { + for v.IsValid() { typ := v.Type() if n := v.Type().NumMethod(); n > 0 { for i := 0; i < n; i++ { @@ -635,23 +635,23 @@ func (t *Template) lookup(st *state, v reflect.Value, name string) reflect.Value } } } - switch av := v.(type) { - case *reflect.PtrValue: + switch av := v; av.Kind() { + case reflect.Ptr: v = av.Elem() - case *reflect.InterfaceValue: + case reflect.Interface: v = av.Elem() - case *reflect.StructValue: + case reflect.Struct: if !isExported(name) { t.execError(st, t.linenum, "name not exported: %s in type %s", name, st.data.Type()) } return av.FieldByName(name) - case *reflect.MapValue: - if v := av.Elem(reflect.NewValue(name)); v != nil { + case reflect.Map: + if v := av.MapIndex(reflect.ValueOf(name)); v.IsValid() { return v } - return reflect.MakeZero(typ.(*reflect.MapType).Elem()) + return reflect.Zero(typ.Elem()) default: - return nil + return reflect.Value{} } } return v @@ -661,8 +661,8 @@ func (t *Template) lookup(st *state, v reflect.Value, name string) reflect.Value // It is forgiving: if the value is not a pointer, it returns it rather than giving // an error. If the pointer is nil, it is returned as is. func indirectPtr(v reflect.Value, numLevels int) reflect.Value { - for i := numLevels; v != nil && i > 0; i++ { - if p, ok := v.(*reflect.PtrValue); ok { + for i := numLevels; v.IsValid() && i > 0; i++ { + if p := v; p.Kind() == reflect.Ptr { if p.IsNil() { return v } @@ -677,11 +677,11 @@ func indirectPtr(v reflect.Value, numLevels int) reflect.Value { // Walk v through pointers and interfaces, extracting the elements within. func indirect(v reflect.Value) reflect.Value { loop: - for v != nil { - switch av := v.(type) { - case *reflect.PtrValue: + for v.IsValid() { + switch av := v; av.Kind() { + case reflect.Ptr: v = av.Elem() - case *reflect.InterfaceValue: + case reflect.Interface: v = av.Elem() default: break loop @@ -708,8 +708,8 @@ func (t *Template) findVar(st *state, s string) reflect.Value { for _, elem := range strings.Split(s, ".", -1) { // Look up field; data must be a struct or map. data = t.lookup(st, data, elem) - if data == nil { - return nil + if !data.IsValid() { + return reflect.Value{} } } return indirectPtr(data, numStars) @@ -718,21 +718,21 @@ func (t *Template) findVar(st *state, s string) reflect.Value { // Is there no data to look at? func empty(v reflect.Value) bool { v = indirect(v) - if v == nil { + if !v.IsValid() { return true } - switch v := v.(type) { - case *reflect.BoolValue: - return v.Get() == false - case *reflect.StringValue: - return v.Get() == "" - case *reflect.StructValue: + switch v.Kind() { + case reflect.Bool: + return v.Bool() == false + case reflect.String: + return v.String() == "" + case reflect.Struct: return false - case *reflect.MapValue: + case reflect.Map: return false - case *reflect.ArrayValue: + case reflect.Array: return v.Len() == 0 - case *reflect.SliceValue: + case reflect.Slice: return v.Len() == 0 } return false @@ -741,7 +741,7 @@ func empty(v reflect.Value) bool { // Look up a variable or method, up through the parent if necessary. func (t *Template) varValue(name string, st *state) reflect.Value { field := t.findVar(st, name) - if field == nil { + if !field.IsValid() { if st.parent == nil { t.execError(st, t.linenum, "name not found: %s in type %s", name, st.data.Type()) } @@ -797,7 +797,7 @@ func (t *Template) executeElement(i int, st *state) int { return elem.end } e := t.elems.At(i) - t.execError(st, 0, "internal error: bad directive in execute: %v %T\n", reflect.NewValue(e).Interface(), e) + t.execError(st, 0, "internal error: bad directive in execute: %v %T\n", reflect.ValueOf(e).Interface(), e) return 0 } @@ -812,7 +812,7 @@ func (t *Template) execute(start, end int, st *state) { func (t *Template) executeSection(s *sectionElement, st *state) { // Find driver data for this section. It must be in the current struct. field := t.varValue(s.field, st) - if field == nil { + if !field.IsValid() { t.execError(st, s.linenum, ".section: cannot find field %s in %s", s.field, st.data.Type()) } st = st.clone(field) @@ -835,29 +835,30 @@ func (t *Template) executeSection(s *sectionElement, st *state) { } // Return the result of calling the Iter method on v, or nil. -func iter(v reflect.Value) *reflect.ChanValue { +func iter(v reflect.Value) reflect.Value { for j := 0; j < v.Type().NumMethod(); j++ { mth := v.Type().Method(j) fv := v.Method(j) - ft := fv.Type().(*reflect.FuncType) + ft := fv.Type() // TODO(rsc): NumIn() should return 0 here, because ft is from a curried FuncValue. if mth.Name != "Iter" || ft.NumIn() != 1 || ft.NumOut() != 1 { continue } - ct, ok := ft.Out(0).(*reflect.ChanType) - if !ok || ct.Dir()&reflect.RecvDir == 0 { + ct := ft.Out(0) + if ct.Kind() != reflect.Chan || + ct.ChanDir()&reflect.RecvDir == 0 { continue } - return fv.Call(nil)[0].(*reflect.ChanValue) + return fv.Call(nil)[0] } - return nil + return reflect.Value{} } // Execute a .repeated section func (t *Template) executeRepeated(r *repeatedElement, st *state) { // Find driver data for this section. It must be in the current struct. field := t.varValue(r.field, st) - if field == nil { + if !field.IsValid() { t.execError(st, r.linenum, ".repeated: cannot find field %s in %s", r.field, st.data.Type()) } field = indirect(field) @@ -885,15 +886,15 @@ func (t *Template) executeRepeated(r *repeatedElement, st *state) { } } - if array, ok := field.(reflect.ArrayOrSliceValue); ok { + if array := field; array.Kind() == reflect.Array || array.Kind() == reflect.Slice { for j := 0; j < array.Len(); j++ { - loopBody(st.clone(array.Elem(j))) + loopBody(st.clone(array.Index(j))) } - } else if m, ok := field.(*reflect.MapValue); ok { - for _, key := range m.Keys() { - loopBody(st.clone(m.Elem(key))) + } else if m := field; m.Kind() == reflect.Map { + for _, key := range m.MapKeys() { + loopBody(st.clone(m.MapIndex(key))) } - } else if ch := iter(field); ch != nil { + } else if ch := iter(field); ch.IsValid() { for { e, ok := ch.Recv() if !ok { @@ -979,7 +980,7 @@ func (t *Template) ParseFile(filename string) (err os.Error) { // generating output to wr. func (t *Template) Execute(wr io.Writer, data interface{}) (err os.Error) { // Extract the driver data. - val := reflect.NewValue(data) + val := reflect.ValueOf(data) defer checkError(&err) t.p = 0 t.execute(0, t.elems.Len(), &state{parent: nil, data: val, wr: wr}) diff --git a/libgo/go/testing/iotest/reader.go b/libgo/go/testing/iotest/reader.go index 647520a09fc..e4003d74450 100644 --- a/libgo/go/testing/iotest/reader.go +++ b/libgo/go/testing/iotest/reader.go @@ -2,8 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The iotest package implements Readers and Writers -// useful only for testing. +// Package iotest implements Readers and Writers useful only for testing. package iotest import ( diff --git a/libgo/go/testing/quick/quick.go b/libgo/go/testing/quick/quick.go index a5568b04830..756a60e1352 100644 --- a/libgo/go/testing/quick/quick.go +++ b/libgo/go/testing/quick/quick.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements utility functions to help with black box testing. +// Package quick implements utility functions to help with black box testing. package quick import ( @@ -53,96 +53,93 @@ const complexSize = 50 // If the type implements the Generator interface, that will be used. // Note: in order to create arbitrary values for structs, all the members must be public. func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) { - if m, ok := reflect.MakeZero(t).Interface().(Generator); ok { + if m, ok := reflect.Zero(t).Interface().(Generator); ok { return m.Generate(rand, complexSize), true } - switch concrete := t.(type) { - case *reflect.BoolType: - return reflect.NewValue(rand.Int()&1 == 0), true - case *reflect.FloatType, *reflect.IntType, *reflect.UintType, *reflect.ComplexType: - switch t.Kind() { - case reflect.Float32: - return reflect.NewValue(randFloat32(rand)), true - case reflect.Float64: - return reflect.NewValue(randFloat64(rand)), true - case reflect.Complex64: - return reflect.NewValue(complex(randFloat32(rand), randFloat32(rand))), true - case reflect.Complex128: - return reflect.NewValue(complex(randFloat64(rand), randFloat64(rand))), true - case reflect.Int16: - return reflect.NewValue(int16(randInt64(rand))), true - case reflect.Int32: - return reflect.NewValue(int32(randInt64(rand))), true - case reflect.Int64: - return reflect.NewValue(randInt64(rand)), true - case reflect.Int8: - return reflect.NewValue(int8(randInt64(rand))), true - case reflect.Int: - return reflect.NewValue(int(randInt64(rand))), true - case reflect.Uint16: - return reflect.NewValue(uint16(randInt64(rand))), true - case reflect.Uint32: - return reflect.NewValue(uint32(randInt64(rand))), true - case reflect.Uint64: - return reflect.NewValue(uint64(randInt64(rand))), true - case reflect.Uint8: - return reflect.NewValue(uint8(randInt64(rand))), true - case reflect.Uint: - return reflect.NewValue(uint(randInt64(rand))), true - case reflect.Uintptr: - return reflect.NewValue(uintptr(randInt64(rand))), true - } - case *reflect.MapType: + switch concrete := t; concrete.Kind() { + case reflect.Bool: + return reflect.ValueOf(rand.Int()&1 == 0), true + case reflect.Float32: + return reflect.ValueOf(randFloat32(rand)), true + case reflect.Float64: + return reflect.ValueOf(randFloat64(rand)), true + case reflect.Complex64: + return reflect.ValueOf(complex(randFloat32(rand), randFloat32(rand))), true + case reflect.Complex128: + return reflect.ValueOf(complex(randFloat64(rand), randFloat64(rand))), true + case reflect.Int16: + return reflect.ValueOf(int16(randInt64(rand))), true + case reflect.Int32: + return reflect.ValueOf(int32(randInt64(rand))), true + case reflect.Int64: + return reflect.ValueOf(randInt64(rand)), true + case reflect.Int8: + return reflect.ValueOf(int8(randInt64(rand))), true + case reflect.Int: + return reflect.ValueOf(int(randInt64(rand))), true + case reflect.Uint16: + return reflect.ValueOf(uint16(randInt64(rand))), true + case reflect.Uint32: + return reflect.ValueOf(uint32(randInt64(rand))), true + case reflect.Uint64: + return reflect.ValueOf(uint64(randInt64(rand))), true + case reflect.Uint8: + return reflect.ValueOf(uint8(randInt64(rand))), true + case reflect.Uint: + return reflect.ValueOf(uint(randInt64(rand))), true + case reflect.Uintptr: + return reflect.ValueOf(uintptr(randInt64(rand))), true + case reflect.Map: numElems := rand.Intn(complexSize) m := reflect.MakeMap(concrete) for i := 0; i < numElems; i++ { key, ok1 := Value(concrete.Key(), rand) value, ok2 := Value(concrete.Elem(), rand) if !ok1 || !ok2 { - return nil, false + return reflect.Value{}, false } - m.SetElem(key, value) + m.SetMapIndex(key, value) } return m, true - case *reflect.PtrType: + case reflect.Ptr: v, ok := Value(concrete.Elem(), rand) if !ok { - return nil, false + return reflect.Value{}, false } - p := reflect.MakeZero(concrete) - p.(*reflect.PtrValue).PointTo(v) + p := reflect.New(concrete.Elem()) + p.Elem().Set(v) return p, true - case *reflect.SliceType: + case reflect.Slice: numElems := rand.Intn(complexSize) s := reflect.MakeSlice(concrete, numElems, numElems) for i := 0; i < numElems; i++ { v, ok := Value(concrete.Elem(), rand) if !ok { - return nil, false + return reflect.Value{}, false } - s.Elem(i).SetValue(v) + s.Index(i).Set(v) } return s, true - case *reflect.StringType: + case reflect.String: numChars := rand.Intn(complexSize) codePoints := make([]int, numChars) for i := 0; i < numChars; i++ { codePoints[i] = rand.Intn(0x10ffff) } - return reflect.NewValue(string(codePoints)), true - case *reflect.StructType: - s := reflect.MakeZero(t).(*reflect.StructValue) + return reflect.ValueOf(string(codePoints)), true + case reflect.Struct: + s := reflect.New(t).Elem() for i := 0; i < s.NumField(); i++ { v, ok := Value(concrete.Field(i).Type, rand) if !ok { - return nil, false + return reflect.Value{}, false } - s.Field(i).SetValue(v) + s.Field(i).Set(v) } return s, true default: - return nil, false + return reflect.Value{}, false } return @@ -247,7 +244,7 @@ func Check(function interface{}, config *Config) (err os.Error) { err = SetupError("function returns more than one value.") return } - if _, ok := fType.Out(0).(*reflect.BoolType); !ok { + if fType.Out(0).Kind() != reflect.Bool { err = SetupError("function does not return a bool") return } @@ -262,7 +259,7 @@ func Check(function interface{}, config *Config) (err os.Error) { return } - if !f.Call(arguments)[0].(*reflect.BoolValue).Get() { + if !f.Call(arguments)[0].Bool() { err = &CheckError{i + 1, toInterfaces(arguments)} return } @@ -320,7 +317,7 @@ func CheckEqual(f, g interface{}, config *Config) (err os.Error) { // arbitraryValues writes Values to args such that args contains Values // suitable for calling f. -func arbitraryValues(args []reflect.Value, f *reflect.FuncType, config *Config, rand *rand.Rand) (err os.Error) { +func arbitraryValues(args []reflect.Value, f reflect.Type, config *Config, rand *rand.Rand) (err os.Error) { if config.Values != nil { config.Values(args, rand) return @@ -338,12 +335,13 @@ func arbitraryValues(args []reflect.Value, f *reflect.FuncType, config *Config, return } -func functionAndType(f interface{}) (v *reflect.FuncValue, t *reflect.FuncType, ok bool) { - v, ok = reflect.NewValue(f).(*reflect.FuncValue) +func functionAndType(f interface{}) (v reflect.Value, t reflect.Type, ok bool) { + v = reflect.ValueOf(f) + ok = v.Kind() == reflect.Func if !ok { return } - t = v.Type().(*reflect.FuncType) + t = v.Type() return } diff --git a/libgo/go/testing/quick/quick_test.go b/libgo/go/testing/quick/quick_test.go index b126e4a1669..f2618c3c255 100644 --- a/libgo/go/testing/quick/quick_test.go +++ b/libgo/go/testing/quick/quick_test.go @@ -102,7 +102,7 @@ type myStruct struct { } func (m myStruct) Generate(r *rand.Rand, _ int) reflect.Value { - return reflect.NewValue(myStruct{x: 42}) + return reflect.ValueOf(myStruct{x: 42}) } func myStructProperty(in myStruct) bool { return in.x == 42 } diff --git a/libgo/go/testing/script/script.go b/libgo/go/testing/script/script.go index b341b1f896b..afb286f5b86 100644 --- a/libgo/go/testing/script/script.go +++ b/libgo/go/testing/script/script.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package aids in the testing of code that uses channels. +// Package script aids in the testing of code that uses channels. package script import ( @@ -134,19 +134,19 @@ type empty struct { } func newEmptyInterface(e empty) reflect.Value { - return reflect.NewValue(e).(*reflect.StructValue).Field(0) + return reflect.ValueOf(e).Field(0) } func (s Send) send() { // With reflect.ChanValue.Send, we must match the types exactly. So, if // s.Channel is a chan interface{} we convert s.Value to an interface{} // first. - c := reflect.NewValue(s.Channel).(*reflect.ChanValue) + c := reflect.ValueOf(s.Channel) var v reflect.Value - if iface, ok := c.Type().(*reflect.ChanType).Elem().(*reflect.InterfaceType); ok && iface.NumMethod() == 0 { + if iface := c.Type().Elem(); iface.Kind() == reflect.Interface && iface.NumMethod() == 0 { v = newEmptyInterface(empty{s.Value}) } else { - v = reflect.NewValue(s.Value) + v = reflect.ValueOf(s.Value) } c.Send(v) } @@ -162,7 +162,7 @@ func (s Close) getSend() sendAction { return s } func (s Close) getChannel() interface{} { return s.Channel } -func (s Close) send() { reflect.NewValue(s.Channel).(*reflect.ChanValue).Close() } +func (s Close) send() { reflect.ValueOf(s.Channel).Close() } // A ReceivedUnexpected error results if no active Events match a value // received from a channel. @@ -278,7 +278,7 @@ func getChannels(events []*Event) ([]interface{}, os.Error) { continue } c := event.action.getChannel() - if _, ok := reflect.NewValue(c).(*reflect.ChanValue); !ok { + if reflect.ValueOf(c).Kind() != reflect.Chan { return nil, SetupError("one of the channel values is not a channel") } @@ -303,7 +303,7 @@ func getChannels(events []*Event) ([]interface{}, os.Error) { // channel repeatedly, wrapping them up as either a channelRecv or // channelClosed structure, and forwards them to the multiplex channel. func recvValues(multiplex chan<- interface{}, channel interface{}) { - c := reflect.NewValue(channel).(*reflect.ChanValue) + c := reflect.ValueOf(channel) for { v, ok := c.Recv() diff --git a/libgo/go/testing/testing.go b/libgo/go/testing/testing.go index 1e65528ef96..8781b207def 100644 --- a/libgo/go/testing/testing.go +++ b/libgo/go/testing/testing.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The testing package provides support for automated testing of Go packages. +// Package testing provides support for automated testing of Go packages. // It is intended to be used in concert with the ``gotest'' utility, which automates // execution of any function of the form // func TestXxx(*testing.T) diff --git a/libgo/go/time/time.go b/libgo/go/time/time.go index 40338f7752a..a0480786aa5 100644 --- a/libgo/go/time/time.go +++ b/libgo/go/time/time.go @@ -2,8 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The time package provides functionality for measuring and -// displaying time. +// Package time provides functionality for measuring and displaying time. package time // Days of the week. diff --git a/libgo/go/try/try.go b/libgo/go/try/try.go index af31d0d2cfc..2a3dbf9870e 100644 --- a/libgo/go/try/try.go +++ b/libgo/go/try/try.go @@ -67,7 +67,7 @@ func printSlice(firstArg string, args []interface{}) { func tryMethods(pkg, firstArg string, args []interface{}) { defer func() { recover() }() // Is the first argument something with methods? - v := reflect.NewValue(args[0]) + v := reflect.ValueOf(args[0]) typ := v.Type() if typ.NumMethod() == 0 { return @@ -90,13 +90,13 @@ func tryMethod(pkg, firstArg string, method reflect.Method, args []interface{}) // tryFunction sees if fn satisfies the arguments. func tryFunction(pkg, name string, fn interface{}, args []interface{}) { defer func() { recover() }() - rfn := reflect.NewValue(fn).(*reflect.FuncValue) - typ := rfn.Type().(*reflect.FuncType) + rfn := reflect.ValueOf(fn) + typ := rfn.Type() tryOneFunction(pkg, "", name, typ, rfn, args) } // tryOneFunction is the common code for tryMethod and tryFunction. -func tryOneFunction(pkg, firstArg, name string, typ *reflect.FuncType, rfn *reflect.FuncValue, args []interface{}) { +func tryOneFunction(pkg, firstArg, name string, typ reflect.Type, rfn reflect.Value, args []interface{}) { // Any results? if typ.NumOut() == 0 { return // Nothing to do. @@ -120,7 +120,7 @@ func tryOneFunction(pkg, firstArg, name string, typ *reflect.FuncType, rfn *refl // Build the call args. argsVal := make([]reflect.Value, typ.NumIn()+typ.NumOut()) for i, a := range args { - argsVal[i] = reflect.NewValue(a) + argsVal[i] = reflect.ValueOf(a) } // Call the function and see if the results are as expected. resultVal := rfn.Call(argsVal[:typ.NumIn()]) @@ -161,12 +161,12 @@ func tryOneFunction(pkg, firstArg, name string, typ *reflect.FuncType, rfn *refl // compatible reports whether the argument is compatible with the type. func compatible(arg interface{}, typ reflect.Type) bool { - if reflect.Typeof(arg) == typ { + if reflect.TypeOf(arg) == typ { return true } if arg == nil { // nil is OK if the type is an interface. - if _, ok := typ.(*reflect.InterfaceType); ok { + if typ.Kind() == reflect.Interface { return true } } diff --git a/libgo/go/unicode/letter.go b/libgo/go/unicode/letter.go index 9380624fd9e..382c6eb3f47 100644 --- a/libgo/go/unicode/letter.go +++ b/libgo/go/unicode/letter.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package provides data and functions to test some properties of Unicode code points. +// Package unicode provides data and functions to test some properties of +// Unicode code points. package unicode const ( diff --git a/libgo/go/utf8/utf8.go b/libgo/go/utf8/utf8.go index 455499e4d95..f542358d6dc 100644 --- a/libgo/go/utf8/utf8.go +++ b/libgo/go/utf8/utf8.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Functions and constants to support text encoded in UTF-8. -// This package calls a Unicode character a rune for brevity. +// Package utf8 implements functions and constants to support text encoded in +// UTF-8. This package calls a Unicode character a rune for brevity. package utf8 import "unicode" // only needed for a couple of constants diff --git a/libgo/go/websocket/server.go b/libgo/go/websocket/server.go index 1119b2d34eb..376265236e2 100644 --- a/libgo/go/websocket/server.go +++ b/libgo/go/websocket/server.go @@ -150,6 +150,7 @@ func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } ws := newConn(origin, location, protocol, buf, rwc) + ws.Request = req f(ws) } diff --git a/libgo/go/websocket/websocket.go b/libgo/go/websocket/websocket.go index d5996abe1a5..edde61b4a76 100644 --- a/libgo/go/websocket/websocket.go +++ b/libgo/go/websocket/websocket.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The websocket package implements a client and server for the Web Socket protocol. +// Package websocket implements a client and server for the Web Socket protocol. // The protocol is defined at http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol package websocket @@ -13,6 +13,7 @@ import ( "bufio" "crypto/md5" "encoding/binary" + "http" "io" "net" "os" @@ -43,6 +44,8 @@ type Conn struct { Location string // The subprotocol for the Web Socket. Protocol string + // The initial http Request (for the Server side only). + Request *http.Request buf *bufio.ReadWriter rwc io.ReadWriteCloser diff --git a/libgo/go/websocket/websocket_test.go b/libgo/go/websocket/websocket_test.go index 8b3cf8925a9..10f88dfd1a0 100644 --- a/libgo/go/websocket/websocket_test.go +++ b/libgo/go/websocket/websocket_test.go @@ -186,11 +186,12 @@ func TestTrailingSpaces(t *testing.T) { once.Do(startServer) for i := 0; i < 30; i++ { // body - _, err := Dial(fmt.Sprintf("ws://%s/echo", serverAddr), "", - "http://localhost/") + ws, err := Dial(fmt.Sprintf("ws://%s/echo", serverAddr), "", "http://localhost/") if err != nil { - panic("Dial failed: " + err.String()) + t.Error("Dial failed:", err.String()) + break } + ws.Close() } } diff --git a/libgo/go/xml/read.go b/libgo/go/xml/read.go index 9ae3bb8eee9..e2b349c3ffb 100644 --- a/libgo/go/xml/read.go +++ b/libgo/go/xml/read.go @@ -139,8 +139,8 @@ import ( // to a freshly allocated value and then mapping the element to that value. // func Unmarshal(r io.Reader, val interface{}) os.Error { - v, ok := reflect.NewValue(val).(*reflect.PtrValue) - if !ok { + v := reflect.ValueOf(val) + if v.Kind() != reflect.Ptr { return os.NewError("non-pointer passed to Unmarshal") } p := NewParser(r) @@ -176,8 +176,8 @@ func (e *TagPathError) String() string { // Passing a nil start element indicates that Unmarshal should // read the token stream to find the start element. func (p *Parser) Unmarshal(val interface{}, start *StartElement) os.Error { - v, ok := reflect.NewValue(val).(*reflect.PtrValue) - if !ok { + v := reflect.ValueOf(val) + if v.Kind() != reflect.Ptr { return os.NewError("non-pointer passed to Unmarshal") } return p.unmarshal(v.Elem(), start) @@ -219,14 +219,11 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error { } } - if pv, ok := val.(*reflect.PtrValue); ok { - if pv.Get() == 0 { - zv := reflect.MakeZero(pv.Type().(*reflect.PtrType).Elem()) - pv.PointTo(zv) - val = zv - } else { - val = pv.Elem() + if pv := val; pv.Kind() == reflect.Ptr { + if pv.IsNil() { + pv.Set(reflect.New(pv.Type().Elem())) } + val = pv.Elem() } var ( @@ -237,17 +234,17 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error { saveXML reflect.Value saveXMLIndex int saveXMLData []byte - sv *reflect.StructValue - styp *reflect.StructType + sv reflect.Value + styp reflect.Type fieldPaths map[string]pathInfo ) - switch v := val.(type) { + switch v := val; v.Kind() { default: return os.ErrorString("unknown type " + v.Type().String()) - case *reflect.SliceValue: - typ := v.Type().(*reflect.SliceType) + case reflect.Slice: + typ := v.Type() if typ.Elem().Kind() == reflect.Uint8 { // []byte saveData = v @@ -269,23 +266,23 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error { v.SetLen(n + 1) // Recur to read element into slice. - if err := p.unmarshal(v.Elem(n), start); err != nil { + if err := p.unmarshal(v.Index(n), start); err != nil { v.SetLen(n) return err } return nil - case *reflect.BoolValue, *reflect.FloatValue, *reflect.IntValue, *reflect.UintValue, *reflect.StringValue: + case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.String: saveData = v - case *reflect.StructValue: + case reflect.Struct: if _, ok := v.Interface().(Name); ok { - v.Set(reflect.NewValue(start.Name).(*reflect.StructValue)) + v.Set(reflect.ValueOf(start.Name)) break } sv = v - typ := sv.Type().(*reflect.StructType) + typ := sv.Type() styp = typ // Assign name. if f, ok := typ.FieldByName("XMLName"); ok { @@ -316,7 +313,7 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error { if _, ok := v.Interface().(Name); !ok { return UnmarshalError(sv.Type().String() + " field XMLName does not have type xml.Name") } - v.(*reflect.StructValue).Set(reflect.NewValue(start.Name).(*reflect.StructValue)) + v.Set(reflect.ValueOf(start.Name)) } // Assign attributes. @@ -325,8 +322,8 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error { f := typ.Field(i) switch f.Tag { case "attr": - strv, ok := sv.FieldByIndex(f.Index).(*reflect.StringValue) - if !ok { + strv := sv.FieldByIndex(f.Index) + if strv.Kind() != reflect.String { return UnmarshalError(sv.Type().String() + " field " + f.Name + " has attr tag but is not type string") } // Look for attribute. @@ -338,20 +335,20 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error { break } } - strv.Set(val) + strv.SetString(val) case "comment": - if saveComment == nil { + if !saveComment.IsValid() { saveComment = sv.FieldByIndex(f.Index) } case "chardata": - if saveData == nil { + if !saveData.IsValid() { saveData = sv.FieldByIndex(f.Index) } case "innerxml": - if saveXML == nil { + if !saveXML.IsValid() { saveXML = sv.FieldByIndex(f.Index) if p.saved == nil { saveXMLIndex = 0 @@ -387,7 +384,7 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error { Loop: for { var savedOffset int - if saveXML != nil { + if saveXML.IsValid() { savedOffset = p.savedOffset() } tok, err := p.Token() @@ -398,7 +395,7 @@ Loop: case StartElement: // Sub-element. // Look up by tag name. - if sv != nil { + if sv.IsValid() { k := fieldName(t.Name.Local) if fieldPaths != nil { @@ -437,7 +434,7 @@ Loop: } case EndElement: - if saveXML != nil { + if saveXML.IsValid() { saveXMLData = p.saved.Bytes()[saveXMLIndex:savedOffset] if saveXMLIndex == 0 { p.saved = nil @@ -446,12 +443,12 @@ Loop: break Loop case CharData: - if saveData != nil { + if saveData.IsValid() { data = append(data, t...) } case Comment: - if saveComment != nil { + if saveComment.IsValid() { comment = append(comment, t...) } } @@ -479,50 +476,50 @@ Loop: } // Save accumulated data and comments - switch t := saveData.(type) { - case nil: + switch t := saveData; t.Kind() { + case reflect.Invalid: // Probably a comment, handled below default: return os.ErrorString("cannot happen: unknown type " + t.Type().String()) - case *reflect.IntValue: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if !getInt64() { return err } - t.Set(itmp) - case *reflect.UintValue: + t.SetInt(itmp) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: if !getUint64() { return err } - t.Set(utmp) - case *reflect.FloatValue: + t.SetUint(utmp) + case reflect.Float32, reflect.Float64: if !getFloat64() { return err } - t.Set(ftmp) - case *reflect.BoolValue: + t.SetFloat(ftmp) + case reflect.Bool: value, err := strconv.Atob(strings.TrimSpace(string(data))) if err != nil { return err } - t.Set(value) - case *reflect.StringValue: - t.Set(string(data)) - case *reflect.SliceValue: - t.Set(reflect.NewValue(data).(*reflect.SliceValue)) + t.SetBool(value) + case reflect.String: + t.SetString(string(data)) + case reflect.Slice: + t.Set(reflect.ValueOf(data)) } - switch t := saveComment.(type) { - case *reflect.StringValue: - t.Set(string(comment)) - case *reflect.SliceValue: - t.Set(reflect.NewValue(comment).(*reflect.SliceValue)) + switch t := saveComment; t.Kind() { + case reflect.String: + t.SetString(string(comment)) + case reflect.Slice: + t.Set(reflect.ValueOf(comment)) } - switch t := saveXML.(type) { - case *reflect.StringValue: - t.Set(string(saveXMLData)) - case *reflect.SliceValue: - t.Set(reflect.NewValue(saveXMLData).(*reflect.SliceValue)) + switch t := saveXML; t.Kind() { + case reflect.String: + t.SetString(string(saveXMLData)) + case reflect.Slice: + t.Set(reflect.ValueOf(saveXMLData)) } return nil @@ -537,7 +534,7 @@ type pathInfo struct { // paths map with all paths leading to it ("a", "a>b", and "a>b>c"). // It is okay for paths to share a common, shorter prefix but not ok // for one path to itself be a prefix of another. -func addFieldPath(sv *reflect.StructValue, paths map[string]pathInfo, path string, fieldIdx []int) os.Error { +func addFieldPath(sv reflect.Value, paths map[string]pathInfo, path string, fieldIdx []int) os.Error { if info, found := paths[path]; found { return tagError(sv, info.fieldIdx, fieldIdx) } @@ -560,8 +557,8 @@ func addFieldPath(sv *reflect.StructValue, paths map[string]pathInfo, path strin } -func tagError(sv *reflect.StructValue, idx1 []int, idx2 []int) os.Error { - t := sv.Type().(*reflect.StructType) +func tagError(sv reflect.Value, idx1 []int, idx2 []int) os.Error { + t := sv.Type() f1 := t.FieldByIndex(idx1) f2 := t.FieldByIndex(idx2) return &TagPathError{t, f1.Name, f1.Tag, f2.Name, f2.Tag} @@ -569,7 +566,7 @@ func tagError(sv *reflect.StructValue, idx1 []int, idx2 []int) os.Error { // unmarshalPaths walks down an XML structure looking for // wanted paths, and calls unmarshal on them. -func (p *Parser) unmarshalPaths(sv *reflect.StructValue, paths map[string]pathInfo, path string, start *StartElement) os.Error { +func (p *Parser) unmarshalPaths(sv reflect.Value, paths map[string]pathInfo, path string, start *StartElement) os.Error { if info, _ := paths[path]; info.complete { return p.unmarshal(sv.FieldByIndex(info.fieldIdx), start) } diff --git a/libgo/go/xml/read_test.go b/libgo/go/xml/read_test.go index a6b9a8ed18b..d4ae3700dba 100644 --- a/libgo/go/xml/read_test.go +++ b/libgo/go/xml/read_test.go @@ -288,9 +288,7 @@ var pathTests = []interface{}{ func TestUnmarshalPaths(t *testing.T) { for _, pt := range pathTests { - p := reflect.MakeZero(reflect.NewValue(pt).Type()).(*reflect.PtrValue) - p.PointTo(reflect.MakeZero(p.Type().(*reflect.PtrType).Elem())) - v := p.Interface() + v := reflect.New(reflect.TypeOf(pt).Elem()).Interface() if err := Unmarshal(StringReader(pathTestString), v); err != nil { t.Fatalf("Unmarshal: %s", err) } @@ -315,8 +313,8 @@ type BadPathTestB struct { var badPathTests = []struct { v, e interface{} }{ - {&BadPathTestA{}, &TagPathError{reflect.Typeof(BadPathTestA{}), "First", "items>item1", "Second", "items>"}}, - {&BadPathTestB{}, &TagPathError{reflect.Typeof(BadPathTestB{}), "First", "items>item1", "Second", "items>item1>value"}}, + {&BadPathTestA{}, &TagPathError{reflect.TypeOf(BadPathTestA{}), "First", "items>item1", "Second", "items>"}}, + {&BadPathTestB{}, &TagPathError{reflect.TypeOf(BadPathTestB{}), "First", "items>item1", "Second", "items>item1>value"}}, } func TestUnmarshalBadPaths(t *testing.T) { diff --git a/libgo/go/xml/xml.go b/libgo/go/xml/xml.go index f92abe82560..42d8b986ecc 100644 --- a/libgo/go/xml/xml.go +++ b/libgo/go/xml/xml.go @@ -163,6 +163,13 @@ type Parser struct { // "quot": `"`, Entity map[string]string + // CharsetReader, if non-nil, defines a function to generate + // charset-conversion readers, converting from the provided + // non-UTF-8 charset into UTF-8. If CharsetReader is nil or + // returns an error, parsing stops with an error. One of the + // the CharsetReader's result values must be non-nil. + CharsetReader func(charset string, input io.Reader) (io.Reader, os.Error) + r io.ByteReader buf bytes.Buffer saved *bytes.Buffer @@ -186,17 +193,7 @@ func NewParser(r io.Reader) *Parser { line: 1, Strict: true, } - - // Get efficient byte at a time reader. - // Assume that if reader has its own - // ReadByte, it's efficient enough. - // Otherwise, use bufio. - if rb, ok := r.(io.ByteReader); ok { - p.r = rb - } else { - p.r = bufio.NewReader(r) - } - + p.switchToReader(r) return p } @@ -290,6 +287,18 @@ func (p *Parser) translate(n *Name, isElementName bool) { } } +func (p *Parser) switchToReader(r io.Reader) { + // Get efficient byte at a time reader. + // Assume that if reader has its own + // ReadByte, it's efficient enough. + // Otherwise, use bufio. + if rb, ok := r.(io.ByteReader); ok { + p.r = rb + } else { + p.r = bufio.NewReader(r) + } +} + // Parsing state - stack holds old name space translations // and the current set of open elements. The translations to pop when // ending a given tag are *below* it on the stack, which is @@ -487,6 +496,25 @@ func (p *Parser) RawToken() (Token, os.Error) { } data := p.buf.Bytes() data = data[0 : len(data)-2] // chop ?> + + if target == "xml" { + enc := procInstEncoding(string(data)) + if enc != "" && enc != "utf-8" && enc != "UTF-8" { + if p.CharsetReader == nil { + p.err = fmt.Errorf("xml: encoding %q declared but Parser.CharsetReader is nil", enc) + return nil, p.err + } + newr, err := p.CharsetReader(enc, p.r.(io.Reader)) + if err != nil { + p.err = fmt.Errorf("xml: opening charset %q: %v", enc, err) + return nil, p.err + } + if newr == nil { + panic("CharsetReader returned a nil Reader for charset " + enc) + } + p.switchToReader(newr) + } + } return ProcInst{target, data}, nil case '!': @@ -1633,3 +1661,26 @@ func Escape(w io.Writer, s []byte) { } w.Write(s[last:]) } + +// procInstEncoding parses the `encoding="..."` or `encoding='...'` +// value out of the provided string, returning "" if not found. +func procInstEncoding(s string) string { + // TODO: this parsing is somewhat lame and not exact. + // It works for all actual cases, though. + idx := strings.Index(s, "encoding=") + if idx == -1 { + return "" + } + v := s[idx+len("encoding="):] + if v == "" { + return "" + } + if v[0] != '\'' && v[0] != '"' { + return "" + } + idx = strings.IndexRune(v[1:], int(v[0])) + if idx == -1 { + return "" + } + return v[1 : idx+1] +} diff --git a/libgo/go/xml/xml_test.go b/libgo/go/xml/xml_test.go index 887bc3d140f..4e51cd53af1 100644 --- a/libgo/go/xml/xml_test.go +++ b/libgo/go/xml/xml_test.go @@ -9,6 +9,7 @@ import ( "io" "os" "reflect" + "strings" "testing" ) @@ -96,6 +97,19 @@ var cookedTokens = []Token{ Comment([]byte(" missing final newline ")), } +const testInputAltEncoding = ` + +VALUE` + +var rawTokensAltEncoding = []Token{ + CharData([]byte("\n")), + ProcInst{"xml", []byte(`version="1.0" encoding="x-testing-uppercase"`)}, + CharData([]byte("\n")), + StartElement{Name{"", "tag"}, nil}, + CharData([]byte("value")), + EndElement{Name{"", "tag"}}, +} + var xmlInput = []string{ // unexpected EOF cases "<", @@ -173,7 +187,64 @@ func StringReader(s string) io.Reader { return &stringReader{s, 0} } func TestRawToken(t *testing.T) { p := NewParser(StringReader(testInput)) + testRawToken(t, p, rawTokens) +} + +type downCaser struct { + t *testing.T + r io.ByteReader +} +func (d *downCaser) ReadByte() (c byte, err os.Error) { + c, err = d.r.ReadByte() + if c >= 'A' && c <= 'Z' { + c += 'a' - 'A' + } + return +} + +func (d *downCaser) Read(p []byte) (int, os.Error) { + d.t.Fatalf("unexpected Read call on downCaser reader") + return 0, os.EINVAL +} + +func TestRawTokenAltEncoding(t *testing.T) { + sawEncoding := "" + p := NewParser(StringReader(testInputAltEncoding)) + p.CharsetReader = func(charset string, input io.Reader) (io.Reader, os.Error) { + sawEncoding = charset + if charset != "x-testing-uppercase" { + t.Fatalf("unexpected charset %q", charset) + } + return &downCaser{t, input.(io.ByteReader)}, nil + } + testRawToken(t, p, rawTokensAltEncoding) +} + +func TestRawTokenAltEncodingNoConverter(t *testing.T) { + p := NewParser(StringReader(testInputAltEncoding)) + token, err := p.RawToken() + if token == nil { + t.Fatalf("expected a token on first RawToken call") + } + if err != nil { + t.Fatal(err) + } + token, err = p.RawToken() + if token != nil { + t.Errorf("expected a nil token; got %#v", token) + } + if err == nil { + t.Fatalf("expected an error on second RawToken call") + } + const encoding = "x-testing-uppercase" + if !strings.Contains(err.String(), encoding) { + t.Errorf("expected error to contain %q; got error: %v", + encoding, err) + } +} + +func testRawToken(t *testing.T, p *Parser, rawTokens []Token) { for i, want := range rawTokens { have, err := p.RawToken() if err != nil { @@ -258,47 +329,51 @@ func TestSyntax(t *testing.T) { } type allScalars struct { - True1 bool - True2 bool - False1 bool - False2 bool - Int int - Int8 int8 - Int16 int16 - Int32 int32 - Int64 int64 - Uint int - Uint8 uint8 - Uint16 uint16 - Uint32 uint32 - Uint64 uint64 - Uintptr uintptr - Float32 float32 - Float64 float64 - String string + True1 bool + True2 bool + False1 bool + False2 bool + Int int + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Uint int + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Uintptr uintptr + Float32 float32 + Float64 float64 + String string + PtrString *string } var all = allScalars{ - True1: true, - True2: true, - False1: false, - False2: false, - Int: 1, - Int8: -2, - Int16: 3, - Int32: -4, - Int64: 5, - Uint: 6, - Uint8: 7, - Uint16: 8, - Uint32: 9, - Uint64: 10, - Uintptr: 11, - Float32: 13.0, - Float64: 14.0, - String: "15", + True1: true, + True2: true, + False1: false, + False2: false, + Int: 1, + Int8: -2, + Int16: 3, + Int32: -4, + Int64: 5, + Uint: 6, + Uint8: 7, + Uint16: 8, + Uint32: 9, + Uint64: 10, + Uintptr: 11, + Float32: 13.0, + Float64: 14.0, + String: "15", + PtrString: &sixteen, } +var sixteen = "16" + const testScalarsInput = ` true 1 @@ -319,6 +394,7 @@ const testScalarsInput = ` 13.0 14.0 15 + 16 ` func TestAllScalars(t *testing.T) { @@ -330,7 +406,7 @@ func TestAllScalars(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(a, all) { - t.Errorf("expected %+v got %+v", all, a) + t.Errorf("have %+v want %+v", a, all) } } @@ -483,3 +559,26 @@ func TestDisallowedCharacters(t *testing.T) { } } } + +type procInstEncodingTest struct { + expect, got string +} + +var procInstTests = []struct { + input, expect string +}{ + {`version="1.0" encoding="utf-8"`, "utf-8"}, + {`version="1.0" encoding='utf-8'`, "utf-8"}, + {`version="1.0" encoding='utf-8' `, "utf-8"}, + {`version="1.0" encoding=utf-8`, ""}, + {`encoding="FOO" `, "FOO"}, +} + +func TestProcInstEncoding(t *testing.T) { + for _, test := range procInstTests { + got := procInstEncoding(test.input) + if got != test.expect { + t.Errorf("procInstEncoding(%q) = %q; want %q", test.input, got, test.expect) + } + } +} diff --git a/libgo/mksysinfo.sh b/libgo/mksysinfo.sh index fecaeb6b312..19885fb000d 100755 --- a/libgo/mksysinfo.sh +++ b/libgo/mksysinfo.sh @@ -74,6 +74,8 @@ cat > sysinfo.c < #endif #include +#include +#include EOF ${CC} -fdump-go-spec=gen-sysinfo.go -std=gnu99 -S -o sysinfo.s sysinfo.c @@ -403,6 +405,10 @@ else echo "type Rusage struct {}" >> ${OUT} fi +# The RUSAGE constants. +grep '^const _RUSAGE_' gen-sysinfo.go | \ + sed -e 's/^\(const \)_\(RUSAGE_[^= ]*\)\(.*\)$/\1\2 = _\2/' >> ${OUT} + # The utsname struct. grep '^type _utsname ' gen-sysinfo.go | \ sed -e 's/_utsname/Utsname/' \ @@ -456,4 +462,20 @@ if test "$fd_set" != ""; then fi echo "type fds_bits_type $fds_bits_type" >> ${OUT} +# The addrinfo struct. +grep '^type _addrinfo ' gen-sysinfo.go | \ + sed -e 's/_addrinfo/Addrinfo/g' \ + -e 's/ ai_/ Ai_/g' \ + >> ${OUT} + +# The addrinfo flags. +grep '^const _AI_' gen-sysinfo.go | \ + sed -e 's/^\(const \)_\(AI_[^= ]*\)\(.*\)$/\1\2 = _\2/' >> ${OUT} + +# The passwd struct. +grep '^type _passwd ' gen-sysinfo.go | \ + sed -e 's/_passwd/Passwd/' \ + -e 's/ pw_/ Pw_/g' \ + >> ${OUT} + exit $? diff --git a/libgo/runtime/chan.goc b/libgo/runtime/chan.goc index 9326f2689ae..acfff859ee0 100644 --- a/libgo/runtime/chan.goc +++ b/libgo/runtime/chan.goc @@ -13,7 +13,8 @@ typedef struct __go_channel chan; /* Do a channel receive with closed status. */ func chanrecv2(c *chan, val *byte) (received bool) { - if (c->element_size > 8) { + uintptr_t element_size = c->element_type->__size; + if (element_size > 8) { return __go_receive_big(c, val, 0); } else { union { @@ -23,10 +24,9 @@ func chanrecv2(c *chan, val *byte) (received bool) { u.v = __go_receive_small_closed(c, 0, &received); #ifndef WORDS_BIGENDIAN - __builtin_memcpy(val, u.b, c->element_size); + __builtin_memcpy(val, u.b, element_size); #else - __builtin_memcpy(val, u.b + 8 - c->element_size, - c->element_size); + __builtin_memcpy(val, u.b + 8 - element_size, element_size); #endif return received; } @@ -35,7 +35,8 @@ func chanrecv2(c *chan, val *byte) (received bool) { /* Do a channel receive with closed status for a select statement. */ func chanrecv3(c *chan, val *byte) (received bool) { - if (c->element_size > 8) { + uintptr_t element_size = c->element_type->__size; + if (element_size > 8) { return __go_receive_big(c, val, 1); } else { union { @@ -45,10 +46,9 @@ func chanrecv3(c *chan, val *byte) (received bool) { u.v = __go_receive_small_closed(c, 1, &received); #ifndef WORDS_BIGENDIAN - __builtin_memcpy(val, u.b, c->element_size); + __builtin_memcpy(val, u.b, element_size); #else - __builtin_memcpy(val, u.b + 8 - c->element_size, - c->element_size); + __builtin_memcpy(val, u.b + 8 - element_size, element_size); #endif return received; } diff --git a/libgo/runtime/channel.h b/libgo/runtime/channel.h index 9dcaf7fcdbb..d4f1632a449 100644 --- a/libgo/runtime/channel.h +++ b/libgo/runtime/channel.h @@ -7,6 +7,8 @@ #include #include +#include "go-type.h" + /* This structure is used when a select is waiting for a synchronous channel. */ @@ -34,8 +36,8 @@ struct __go_channel /* A condition variable. This is signalled when data is added to the channel and when data is removed from the channel. */ pthread_cond_t cond; - /* The size of elements on this channel. */ - size_t element_size; + /* The type of elements on this channel. */ + const struct __go_type_descriptor *element_type; /* True if a goroutine is waiting to send on a synchronous channel. */ _Bool waiting_to_send; @@ -82,7 +84,8 @@ typedef struct __go_channel __go_channel; acquired while this mutex is held. */ extern pthread_mutex_t __go_select_data_mutex; -extern struct __go_channel *__go_new_channel (uintptr_t, uintptr_t); +extern struct __go_channel * +__go_new_channel (const struct __go_type_descriptor *, uintptr_t); extern _Bool __go_synch_with_select (struct __go_channel *, _Bool); diff --git a/libgo/runtime/go-eface-compare.c b/libgo/runtime/go-eface-compare.c index c90177e2085..673440542f3 100644 --- a/libgo/runtime/go-eface-compare.c +++ b/libgo/runtime/go-eface-compare.c @@ -4,6 +4,7 @@ Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. */ +#include "go-panic.h" #include "interface.h" /* Compare two interface values. Return 0 for equal, not zero for not @@ -16,6 +17,11 @@ __go_empty_interface_compare (struct __go_empty_interface left, const struct __go_type_descriptor *left_descriptor; left_descriptor = left.__type_descriptor; + + if (((uintptr_t) left_descriptor & reflectFlags) != 0 + || ((uintptr_t) right.__type_descriptor & reflectFlags) != 0) + __go_panic_msg ("invalid interface value"); + if (left_descriptor == NULL && right.__type_descriptor == NULL) return 0; if (left_descriptor == NULL || right.__type_descriptor == NULL) diff --git a/libgo/runtime/go-eface-val-compare.c b/libgo/runtime/go-eface-val-compare.c index 319ede24301..d754cc5f36e 100644 --- a/libgo/runtime/go-eface-val-compare.c +++ b/libgo/runtime/go-eface-val-compare.c @@ -4,6 +4,7 @@ Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. */ +#include "go-panic.h" #include "go-type.h" #include "interface.h" @@ -19,6 +20,8 @@ __go_empty_interface_value_compare ( const struct __go_type_descriptor *left_descriptor; left_descriptor = left.__type_descriptor; + if (((uintptr_t) left_descriptor & reflectFlags) != 0) + __go_panic_msg ("invalid interface value"); if (left_descriptor == NULL) return 1; if (!__go_type_descriptors_equal (left_descriptor, right_descriptor)) diff --git a/libgo/runtime/go-interface-eface-compare.c b/libgo/runtime/go-interface-eface-compare.c index 9de8424acc1..ff69a2749d4 100644 --- a/libgo/runtime/go-interface-eface-compare.c +++ b/libgo/runtime/go-interface-eface-compare.c @@ -4,6 +4,7 @@ Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. */ +#include "go-panic.h" #include "interface.h" /* Compare a non-empty interface value with an empty interface value. @@ -16,6 +17,8 @@ __go_interface_empty_compare (struct __go_interface left, { const struct __go_type_descriptor *left_descriptor; + if (((uintptr_t) right.__type_descriptor & reflectFlags) != 0) + __go_panic_msg ("invalid interface value"); if (left.__methods == NULL && right.__type_descriptor == NULL) return 0; if (left.__methods == NULL || right.__type_descriptor == NULL) diff --git a/libgo/runtime/go-new-channel.c b/libgo/runtime/go-new-channel.c index 028715e3b1d..e440e873652 100644 --- a/libgo/runtime/go-new-channel.c +++ b/libgo/runtime/go-new-channel.c @@ -13,12 +13,16 @@ #include "channel.h" struct __go_channel* -__go_new_channel (uintptr_t element_size, uintptr_t entries) +__go_new_channel (const struct __go_type_descriptor *element_type, + uintptr_t entries) { + uintptr_t element_size; struct __go_channel* ret; size_t alloc_size; int i; + element_size = element_type->__size; + if ((uintptr_t) (int) entries != entries || entries > (uintptr_t) -1 / element_size) __go_panic_msg ("chan size out of range"); @@ -40,7 +44,7 @@ __go_new_channel (uintptr_t element_size, uintptr_t entries) __go_assert (i == 0); i = pthread_cond_init (&ret->cond, NULL); __go_assert (i == 0); - ret->element_size = element_size; + ret->element_type = element_type; ret->waiting_to_send = 0; ret->waiting_to_receive = 0; ret->selected_for_send = 0; diff --git a/libgo/runtime/go-rec-big.c b/libgo/runtime/go-rec-big.c index fd3923ce272..580ccb06719 100644 --- a/libgo/runtime/go-rec-big.c +++ b/libgo/runtime/go-rec-big.c @@ -15,23 +15,24 @@ _Bool __go_receive_big (struct __go_channel *channel, void *val, _Bool for_select) { + uintptr_t element_size; size_t alloc_size; size_t offset; if (channel == NULL) __go_panic_msg ("receive from nil channel"); - alloc_size = ((channel->element_size + sizeof (uint64_t) - 1) - / sizeof (uint64_t)); + element_size = channel->element_type->__size; + alloc_size = (element_size + sizeof (uint64_t) - 1) / sizeof (uint64_t); if (!__go_receive_acquire (channel, for_select)) { - __builtin_memset (val, 0, channel->element_size); + __builtin_memset (val, 0, element_size); return 0; } offset = channel->next_fetch * alloc_size; - __builtin_memcpy (val, &channel->data[offset], channel->element_size); + __builtin_memcpy (val, &channel->data[offset], element_size); __go_receive_release (channel); diff --git a/libgo/runtime/go-rec-nb-big.c b/libgo/runtime/go-rec-nb-big.c index 78db587345f..8c315b19af7 100644 --- a/libgo/runtime/go-rec-nb-big.c +++ b/libgo/runtime/go-rec-nb-big.c @@ -14,23 +14,24 @@ _Bool __go_receive_nonblocking_big (struct __go_channel* channel, void *val, _Bool *closed) { + uintptr_t element_size; size_t alloc_size; size_t offset; - alloc_size = ((channel->element_size + sizeof (uint64_t) - 1) - / sizeof (uint64_t)); + element_size = channel->element_type->__size; + alloc_size = (element_size + sizeof (uint64_t) - 1) / sizeof (uint64_t); int data = __go_receive_nonblocking_acquire (channel); if (data != RECEIVE_NONBLOCKING_ACQUIRE_DATA) { - __builtin_memset (val, 0, channel->element_size); + __builtin_memset (val, 0, element_size); if (closed != NULL) *closed = data == RECEIVE_NONBLOCKING_ACQUIRE_CLOSED; return 0; } offset = channel->next_fetch * alloc_size; - __builtin_memcpy (val, &channel->data[offset], channel->element_size); + __builtin_memcpy (val, &channel->data[offset], element_size); __go_receive_release (channel); diff --git a/libgo/runtime/go-rec-nb-small.c b/libgo/runtime/go-rec-nb-small.c index 29cad9af620..eb0a25e12b7 100644 --- a/libgo/runtime/go-rec-nb-small.c +++ b/libgo/runtime/go-rec-nb-small.c @@ -94,9 +94,11 @@ __go_receive_nonblocking_acquire (struct __go_channel *channel) struct __go_receive_nonblocking_small __go_receive_nonblocking_small (struct __go_channel *channel) { + uintptr_t element_size; struct __go_receive_nonblocking_small ret; - __go_assert (channel->element_size <= sizeof (uint64_t)); + element_size = channel->element_type->__size; + __go_assert (element_size <= sizeof (uint64_t)); int data = __go_receive_nonblocking_acquire (channel); if (data != RECEIVE_NONBLOCKING_ACQUIRE_DATA) diff --git a/libgo/runtime/go-rec-small.c b/libgo/runtime/go-rec-small.c index 019b20a456b..946a18c0c8e 100644 --- a/libgo/runtime/go-rec-small.c +++ b/libgo/runtime/go-rec-small.c @@ -266,12 +266,14 @@ uint64_t __go_receive_small_closed (struct __go_channel *channel, _Bool for_select, _Bool *received) { + uintptr_t element_size; uint64_t ret; if (channel == NULL) __go_panic_msg ("receive from nil channel"); - __go_assert (channel->element_size <= sizeof (uint64_t)); + element_size = channel->element_type->__size; + __go_assert (element_size <= sizeof (uint64_t)); if (!__go_receive_acquire (channel, for_select)) { diff --git a/libgo/runtime/go-reflect-call.c b/libgo/runtime/go-reflect-call.c index 6ae749f9a56..a769142c3df 100644 --- a/libgo/runtime/go-reflect-call.c +++ b/libgo/runtime/go-reflect-call.c @@ -14,9 +14,32 @@ #include "go-type.h" #include "runtime.h" -/* Forward declaration. */ - -static ffi_type *go_type_to_ffi (const struct __go_type_descriptor *); +/* The functions in this file are only called from reflect_call. As + reflect_call calls a libffi function, which will be compiled + without -fsplit-stack, it will always run with a large stack. */ + +static ffi_type *go_array_to_ffi (const struct __go_array_type *) + __attribute__ ((no_split_stack)); +static ffi_type *go_slice_to_ffi (const struct __go_slice_type *) + __attribute__ ((no_split_stack)); +static ffi_type *go_struct_to_ffi (const struct __go_struct_type *) + __attribute__ ((no_split_stack)); +static ffi_type *go_string_to_ffi (void) __attribute__ ((no_split_stack)); +static ffi_type *go_interface_to_ffi (void) __attribute__ ((no_split_stack)); +static ffi_type *go_complex_to_ffi (ffi_type *) + __attribute__ ((no_split_stack)); +static ffi_type *go_type_to_ffi (const struct __go_type_descriptor *) + __attribute__ ((no_split_stack)); +static ffi_type *go_func_return_ffi (const struct __go_func_type *) + __attribute__ ((no_split_stack)); +static void go_func_to_cif (const struct __go_func_type *, _Bool, _Bool, + ffi_cif *) + __attribute__ ((no_split_stack)); +static size_t go_results_size (const struct __go_func_type *) + __attribute__ ((no_split_stack)); +static void go_set_results (const struct __go_func_type *, unsigned char *, + void **) + __attribute__ ((no_split_stack)); /* Return an ffi_type for a Go array type. The libffi library does not have any builtin support for passing arrays as values. We work @@ -31,7 +54,6 @@ go_array_to_ffi (const struct __go_array_type *descriptor) uintptr_t i; ret = (ffi_type *) __go_alloc (sizeof (ffi_type)); - __builtin_memset (ret, 0, sizeof (ffi_type)); ret->type = FFI_TYPE_STRUCT; len = descriptor->__len; ret->elements = (ffi_type **) __go_alloc ((len + 1) * sizeof (ffi_type *)); @@ -52,7 +74,6 @@ go_slice_to_ffi ( ffi_type *ret; ret = (ffi_type *) __go_alloc (sizeof (ffi_type)); - __builtin_memset (ret, 0, sizeof (ffi_type)); ret->type = FFI_TYPE_STRUCT; ret->elements = (ffi_type **) __go_alloc (4 * sizeof (ffi_type *)); ret->elements[0] = &ffi_type_pointer; @@ -73,7 +94,6 @@ go_struct_to_ffi (const struct __go_struct_type *descriptor) int i; ret = (ffi_type *) __go_alloc (sizeof (ffi_type)); - __builtin_memset (ret, 0, sizeof (ffi_type)); ret->type = FFI_TYPE_STRUCT; field_count = descriptor->__fields.__count; fields = (const struct __go_struct_field *) descriptor->__fields.__values; @@ -237,7 +257,6 @@ go_func_return_ffi (const struct __go_func_type *func) return go_type_to_ffi (types[0]); ret = (ffi_type *) __go_alloc (sizeof (ffi_type)); - __builtin_memset (ret, 0, sizeof (ffi_type)); ret->type = FFI_TYPE_STRUCT; ret->elements = (ffi_type **) __go_alloc ((count + 1) * sizeof (ffi_type *)); for (i = 0; i < count; ++i) @@ -251,7 +270,7 @@ go_func_return_ffi (const struct __go_func_type *func) static void go_func_to_cif (const struct __go_func_type *func, _Bool is_interface, - ffi_cif *cif) + _Bool is_method, ffi_cif *cif) { int num_params; const struct __go_type_descriptor **in_types; @@ -268,10 +287,19 @@ go_func_to_cif (const struct __go_func_type *func, _Bool is_interface, num_args = num_params + (is_interface ? 1 : 0); args = (ffi_type **) __go_alloc (num_args * sizeof (ffi_type *)); + i = 0; + off = 0; if (is_interface) - args[0] = &ffi_type_pointer; - off = is_interface ? 1 : 0; - for (i = 0; i < num_params; ++i) + { + args[0] = &ffi_type_pointer; + off = 1; + } + else if (is_method) + { + args[0] = &ffi_type_pointer; + i = 1; + } + for (; i < num_params; ++i) args[i + off] = go_type_to_ffi (in_types[i]); rettype = go_func_return_ffi (func); @@ -354,13 +382,14 @@ go_set_results (const struct __go_func_type *func, unsigned char *call_result, void reflect_call (const struct __go_func_type *func_type, const void *func_addr, - _Bool is_interface, void **params, void **results) + _Bool is_interface, _Bool is_method, void **params, + void **results) { ffi_cif cif; unsigned char *call_result; __go_assert (func_type->__common.__code == GO_FUNC); - go_func_to_cif (func_type, is_interface, &cif); + go_func_to_cif (func_type, is_interface, is_method, &cif); call_result = (unsigned char *) malloc (go_results_size (func_type)); diff --git a/libgo/runtime/go-reflect-chan.c b/libgo/runtime/go-reflect-chan.c index c9ccda7cefc..d568024b3df 100644 --- a/libgo/runtime/go-reflect-chan.c +++ b/libgo/runtime/go-reflect-chan.c @@ -8,30 +8,54 @@ #include #include "config.h" +#include "go-alloc.h" +#include "go-assert.h" +#include "go-panic.h" #include "go-type.h" #include "channel.h" /* This file implements support for reflection on channels. These functions are called from reflect/value.go. */ -extern unsigned char *makechan (const struct __go_type_descriptor *, uint32_t) +extern uintptr_t makechan (const struct __go_type_descriptor *, uint32_t) asm ("libgo_reflect.reflect.makechan"); -unsigned char * +uintptr_t makechan (const struct __go_type_descriptor *typ, uint32_t size) { - return (unsigned char *) __go_new_channel (typ->__size, size); + struct __go_channel *channel; + void *ret; + + __go_assert (typ->__code == GO_CHAN); + typ = ((const struct __go_channel_type *) typ)->__element_type; + + channel = __go_new_channel (typ, size); + + ret = __go_alloc (sizeof (void *)); + __builtin_memcpy (ret, &channel, sizeof (void *)); + return (uintptr_t) ret; } -extern void chansend (unsigned char *, unsigned char *, _Bool *) +extern _Bool chansend (uintptr_t, uintptr_t, _Bool) asm ("libgo_reflect.reflect.chansend"); -void -chansend (unsigned char *ch, unsigned char *val, _Bool *selected) +_Bool +chansend (uintptr_t ch, uintptr_t val_i, _Bool nb) { struct __go_channel *channel = (struct __go_channel *) ch; + uintptr_t element_size; + void *pv; + + if (channel == NULL) + __go_panic_msg ("send to nil channel"); + + if (__go_is_pointer_type (channel->element_type)) + pv = &val_i; + else + pv = (void *) val_i; - if (channel->element_size <= sizeof (uint64_t)) + element_size = channel->element_type->__size; + if (element_size <= sizeof (uint64_t)) { union { @@ -41,35 +65,60 @@ chansend (unsigned char *ch, unsigned char *val, _Bool *selected) __builtin_memset (u.b, 0, sizeof (uint64_t)); #ifndef WORDS_BIGENDIAN - __builtin_memcpy (u.b, val, channel->element_size); + __builtin_memcpy (u.b, pv, element_size); #else - __builtin_memcpy (u.b + sizeof (uint64_t) - channel->element_size, val, - channel->element_size); + __builtin_memcpy (u.b + sizeof (uint64_t) - element_size, pv, + element_size); #endif - if (selected == NULL) - __go_send_small (channel, u.v, 0); + if (nb) + return __go_send_nonblocking_small (channel, u.v); else - *selected = __go_send_nonblocking_small (channel, u.v); + { + __go_send_small (channel, u.v, 0); + return 1; + } } else { - if (selected == NULL) - __go_send_big (channel, val, 0); + if (nb) + return __go_send_nonblocking_big (channel, pv); else - *selected = __go_send_nonblocking_big (channel, val); + { + __go_send_big (channel, pv, 0); + return 1; + } } } -extern void chanrecv (unsigned char *, unsigned char *, _Bool *, _Bool *) +struct chanrecv_ret +{ + uintptr_t val; + _Bool selected; + _Bool received; +}; + +extern struct chanrecv_ret chanrecv (uintptr_t, _Bool) asm ("libgo_reflect.reflect.chanrecv"); -void -chanrecv (unsigned char *ch, unsigned char *val, _Bool *selected, - _Bool *received) +struct chanrecv_ret +chanrecv (uintptr_t ch, _Bool nb) { struct __go_channel *channel = (struct __go_channel *) ch; + void *pv; + uintptr_t element_size; + struct chanrecv_ret ret; + + element_size = channel->element_type->__size; - if (channel->element_size <= sizeof (uint64_t)) + if (__go_is_pointer_type (channel->element_type)) + pv = &ret.val; + else + { + pv = __go_alloc (element_size); + ret.val = (uintptr_t) pv; + } + + if (element_size <= sizeof (uint64_t)) { union { @@ -77,74 +126,73 @@ chanrecv (unsigned char *ch, unsigned char *val, _Bool *selected, uint64_t v; } u; - if (selected == NULL) - u.v = __go_receive_small_closed (channel, 0, received); + if (!nb) + { + u.v = __go_receive_small_closed (channel, 0, &ret.received); + ret.selected = 1; + } else { struct __go_receive_nonblocking_small s; s = __go_receive_nonblocking_small (channel); - *selected = s.__success || s.__closed; - if (received != NULL) - *received = s.__success; + ret.selected = s.__success || s.__closed; + ret.received = s.__success; u.v = s.__val; } #ifndef WORDS_BIGENDIAN - __builtin_memcpy (val, u.b, channel->element_size); + __builtin_memcpy (pv, u.b, element_size); #else - __builtin_memcpy (val, u.b + sizeof (uint64_t) - channel->element_size, - channel->element_size); + __builtin_memcpy (pv, u.b + sizeof (uint64_t) - element_size, + element_size); #endif } else { - if (selected == NULL) + if (!nb) { - _Bool success; - - success = __go_receive_big (channel, val, 0); - if (received != NULL) - *received = success; + ret.received = __go_receive_big (channel, pv, 0); + ret.selected = 1; } else { _Bool got; _Bool closed; - got = __go_receive_nonblocking_big (channel, val, &closed); - *selected = got || closed; - if (received != NULL) - *received = got; + got = __go_receive_nonblocking_big (channel, pv, &closed); + ret.selected = got || closed; + ret.received = got; } } + + return ret; } -extern void chanclose (unsigned char *) - asm ("libgo_reflect.reflect.chanclose"); +extern void chanclose (uintptr_t) asm ("libgo_reflect.reflect.chanclose"); void -chanclose (unsigned char *ch) +chanclose (uintptr_t ch) { struct __go_channel *channel = (struct __go_channel *) ch; __go_builtin_close (channel); } -extern int32_t chanlen (unsigned char *) asm ("libgo_reflect.reflect.chanlen"); +extern int32_t chanlen (uintptr_t) asm ("libgo_reflect.reflect.chanlen"); int32_t -chanlen (unsigned char *ch) +chanlen (uintptr_t ch) { struct __go_channel *channel = (struct __go_channel *) ch; return (int32_t) __go_chan_len (channel); } -extern int32_t chancap (unsigned char *) asm ("libgo_reflect.reflect.chancap"); +extern int32_t chancap (uintptr_t) asm ("libgo_reflect.reflect.chancap"); int32_t -chancap (unsigned char *ch) +chancap (uintptr_t ch) { struct __go_channel *channel = (struct __go_channel *) ch; diff --git a/libgo/runtime/go-reflect-map.c b/libgo/runtime/go-reflect-map.c index 67960dee41e..5559f6eadaf 100644 --- a/libgo/runtime/go-reflect-map.c +++ b/libgo/runtime/go-reflect-map.c @@ -8,69 +8,125 @@ #include #include "go-alloc.h" +#include "go-panic.h" #include "go-type.h" #include "map.h" /* This file implements support for reflection on maps. These functions are called from reflect/value.go. */ -extern _Bool mapaccess (unsigned char *, unsigned char *, unsigned char *) +struct mapaccess_ret +{ + uintptr_t val; + _Bool pres; +}; + +extern struct mapaccess_ret mapaccess (uintptr_t, uintptr_t) asm ("libgo_reflect.reflect.mapaccess"); -_Bool -mapaccess (unsigned char *m, unsigned char *key, unsigned char *val) +struct mapaccess_ret +mapaccess (uintptr_t m, uintptr_t key_i) { struct __go_map *map = (struct __go_map *) m; + void *key; + const struct __go_type_descriptor *key_descriptor; void *p; const struct __go_type_descriptor *val_descriptor; + struct mapaccess_ret ret; + void *val; + void *pv; + + if (map == NULL) + __go_panic_msg ("lookup in nil map"); + + key_descriptor = map->__descriptor->__map_descriptor->__key_type; + if (__go_is_pointer_type (key_descriptor)) + key = &key_i; + else + key = (void *) key_i; p = __go_map_index (map, key, 0); + + val_descriptor = map->__descriptor->__map_descriptor->__val_type; + if (__go_is_pointer_type (val_descriptor)) + { + val = NULL; + pv = &val; + } + else + { + val = __go_alloc (val_descriptor->__size); + pv = val; + } + if (p == NULL) - return 0; + ret.pres = 0; else { - val_descriptor = map->__descriptor->__map_descriptor->__val_type; - __builtin_memcpy (val, p, val_descriptor->__size); - return 1; + __builtin_memcpy (pv, p, val_descriptor->__size); + ret.pres = 1; } + + ret.val = (uintptr_t) val; + return ret; } -extern void mapassign (unsigned char *, unsigned char *, unsigned char *) +extern void mapassign (uintptr_t, uintptr_t, uintptr_t, _Bool) asm ("libgo_reflect.reflect.mapassign"); void -mapassign (unsigned char *m, unsigned char *key, unsigned char *val) +mapassign (uintptr_t m, uintptr_t key_i, uintptr_t val_i, _Bool pres) { struct __go_map *map = (struct __go_map *) m; + const struct __go_type_descriptor *key_descriptor; + void *key; - if (val == NULL) + if (map == NULL) + __go_panic_msg ("lookup in nil map"); + + key_descriptor = map->__descriptor->__map_descriptor->__key_type; + if (__go_is_pointer_type (key_descriptor)) + key = &key_i; + else + key = (void *) key_i; + + if (!pres) __go_map_delete (map, key); else { void *p; const struct __go_type_descriptor *val_descriptor; + void *pv; p = __go_map_index (map, key, 1); + val_descriptor = map->__descriptor->__map_descriptor->__val_type; - __builtin_memcpy (p, val, val_descriptor->__size); + if (__go_is_pointer_type (val_descriptor)) + pv = &val_i; + else + pv = (void *) val_i; + __builtin_memcpy (p, pv, val_descriptor->__size); } } -extern int32_t maplen (unsigned char *) +extern int32_t maplen (uintptr_t) asm ("libgo_reflect.reflect.maplen"); int32_t -maplen (unsigned char *m __attribute__ ((unused))) +maplen (uintptr_t m) { struct __go_map *map = (struct __go_map *) m; + + if (map == NULL) + return 0; return (int32_t) map->__element_count; } -extern unsigned char *mapiterinit (unsigned char *) +extern unsigned char *mapiterinit (uintptr_t) asm ("libgo_reflect.reflect.mapiterinit"); unsigned char * -mapiterinit (unsigned char *m) +mapiterinit (uintptr_t m) { struct __go_hash_iter *it; @@ -88,35 +144,67 @@ mapiternext (unsigned char *it) __go_mapiternext ((struct __go_hash_iter *) it); } -extern _Bool mapiterkey (unsigned char *, unsigned char *) +struct mapiterkey_ret +{ + uintptr_t key; + _Bool ok; +}; + +extern struct mapiterkey_ret mapiterkey (unsigned char *) asm ("libgo_reflect.reflect.mapiterkey"); -_Bool -mapiterkey (unsigned char *ita, unsigned char *key) +struct mapiterkey_ret +mapiterkey (unsigned char *ita) { struct __go_hash_iter *it = (struct __go_hash_iter *) ita; + struct mapiterkey_ret ret; if (it->entry == NULL) - return 0; + { + ret.key = 0; + ret.ok = 0; + } else { - __go_mapiter1 (it, key); - return 1; + const struct __go_type_descriptor *key_descriptor; + void *key; + void *pk; + + key_descriptor = it->map->__descriptor->__map_descriptor->__key_type; + if (__go_is_pointer_type (key_descriptor)) + { + key = NULL; + pk = &key; + } + else + { + key = __go_alloc (key_descriptor->__size); + pk = key; + } + + __go_mapiter1 (it, pk); + + ret.key = (uintptr_t) key; + ret.ok = 1; } + + return ret; } /* Make a new map. We have to build our own map descriptor. */ -extern unsigned char *makemap (const struct __go_map_type *) +extern uintptr_t makemap (const struct __go_map_type *) asm ("libgo_reflect.reflect.makemap"); -unsigned char * +uintptr_t makemap (const struct __go_map_type *t) { struct __go_map_descriptor *md; unsigned int o; const struct __go_type_descriptor *kt; const struct __go_type_descriptor *vt; + struct __go_map* map; + void *ret; /* FIXME: Reference count. */ md = (struct __go_map_descriptor *) __go_alloc (sizeof (*md)); @@ -135,5 +223,9 @@ makemap (const struct __go_map_type *t) o = (o + vt->__field_align - 1) & ~ (vt->__field_align - 1); md->__entry_size = o; - return (unsigned char *) __go_new_map (md, 0); + map = __go_new_map (md, 0); + + ret = __go_alloc (sizeof (void *)); + __builtin_memcpy (ret, &map, sizeof (void *)); + return (uintptr_t) ret; } diff --git a/libgo/runtime/go-reflect.c b/libgo/runtime/go-reflect.c index 9485c0979b6..bf13a11fae2 100644 --- a/libgo/runtime/go-reflect.c +++ b/libgo/runtime/go-reflect.c @@ -121,6 +121,9 @@ Reflect (struct __go_empty_interface e) { struct reflect_ret ret; + if (((uintptr_t) e.__type_descriptor & reflectFlags) != 0) + __go_panic_msg ("invalid interface value"); + if (e.__type_descriptor == NULL) { ret.rettype.__type_descriptor = NULL; @@ -166,6 +169,9 @@ Typeof (const struct __go_empty_interface e) { struct __go_empty_interface ret; + if (((uintptr_t) e.__type_descriptor & reflectFlags) != 0) + __go_panic_msg ("invalid interface value"); + if (e.__type_descriptor == NULL) { ret.__type_descriptor = NULL; diff --git a/libgo/runtime/go-send-big.c b/libgo/runtime/go-send-big.c index c2732639536..0b4aa046d12 100644 --- a/libgo/runtime/go-send-big.c +++ b/libgo/runtime/go-send-big.c @@ -12,19 +12,20 @@ void __go_send_big (struct __go_channel* channel, const void *val, _Bool for_select) { + uintptr_t element_size; size_t alloc_size; size_t offset; if (channel == NULL) __go_panic_msg ("send to nil channel"); - alloc_size = ((channel->element_size + sizeof (uint64_t) - 1) - / sizeof (uint64_t)); + element_size = channel->element_type->__size; + alloc_size = (element_size + sizeof (uint64_t) - 1) / sizeof (uint64_t); __go_send_acquire (channel, for_select); offset = channel->next_store * alloc_size; - __builtin_memcpy (&channel->data[offset], val, channel->element_size); + __builtin_memcpy (&channel->data[offset], val, element_size); __go_send_release (channel); } diff --git a/libgo/runtime/go-send-nb-big.c b/libgo/runtime/go-send-nb-big.c index 1d33dd6207b..12748591706 100644 --- a/libgo/runtime/go-send-nb-big.c +++ b/libgo/runtime/go-send-nb-big.c @@ -11,17 +11,18 @@ _Bool __go_send_nonblocking_big (struct __go_channel* channel, const void *val) { + uintptr_t element_size; size_t alloc_size; size_t offset; - alloc_size = ((channel->element_size + sizeof (uint64_t) - 1) - / sizeof (uint64_t)); + element_size = channel->element_type->__size; + alloc_size = (element_size + sizeof (uint64_t) - 1) / sizeof (uint64_t); if (!__go_send_nonblocking_acquire (channel)) return 0; offset = channel->next_store * alloc_size; - __builtin_memcpy (&channel->data[offset], val, channel->element_size); + __builtin_memcpy (&channel->data[offset], val, element_size); __go_send_release (channel); diff --git a/libgo/runtime/go-send-nb-small.c b/libgo/runtime/go-send-nb-small.c index 5c49a67ffb4..0a25ba369ea 100644 --- a/libgo/runtime/go-send-nb-small.c +++ b/libgo/runtime/go-send-nb-small.c @@ -93,7 +93,7 @@ __go_send_nonblocking_acquire (struct __go_channel *channel) _Bool __go_send_nonblocking_small (struct __go_channel *channel, uint64_t val) { - __go_assert (channel->element_size <= sizeof (uint64_t)); + __go_assert (channel->element_type->__size <= sizeof (uint64_t)); if (!__go_send_nonblocking_acquire (channel)) return 0; diff --git a/libgo/runtime/go-send-small.c b/libgo/runtime/go-send-small.c index 56f9470911f..8e21d368588 100644 --- a/libgo/runtime/go-send-small.c +++ b/libgo/runtime/go-send-small.c @@ -147,7 +147,7 @@ __go_send_small (struct __go_channel *channel, uint64_t val, _Bool for_select) if (channel == NULL) __go_panic_msg ("send to nil channel"); - __go_assert (channel->element_size <= sizeof (uint64_t)); + __go_assert (channel->element_type->__size <= sizeof (uint64_t)); __go_send_acquire (channel, for_select); diff --git a/libgo/runtime/go-setenv.c b/libgo/runtime/go-setenv.c new file mode 100644 index 00000000000..20f99399759 --- /dev/null +++ b/libgo/runtime/go-setenv.c @@ -0,0 +1,50 @@ +/* go-setenv.c -- set the C environment from Go. + + Copyright 2011 The Go Authors. All rights reserved. + Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. */ + +#include +#include + +#include "go-alloc.h" +#include "go-string.h" + +/* Set the C environment from Go. This is called by os.Setenv. */ + +void setenv_c (struct __go_string, struct __go_string) + __asm__ ("libgo_os.os.setenv_c"); + +void +setenv_c (struct __go_string k, struct __go_string v) +{ + const unsigned char *ks; + unsigned char *kn; + const unsigned char *vs; + unsigned char *vn; + + ks = k.__data; + kn = NULL; + if (ks[k.__length] != 0) + { + kn = __go_alloc (k.__length + 1); + __builtin_memcpy (kn, k.__data, k.__length); + ks = kn; + } + + vs = v.__data; + vn = NULL; + if (vs[v.__length] != 0) + { + vn = __go_alloc (v.__length + 1); + __builtin_memcpy (vn, v.__data, v.__length); + vs = vn; + } + + setenv ((const char *) ks, (const char *) vs, 1); + + if (kn != NULL) + __go_free (kn); + if (vn != NULL) + __go_free (vn); +} diff --git a/libgo/runtime/go-signal.c b/libgo/runtime/go-signal.c index edeeccc79eb..c16b058b79e 100644 --- a/libgo/runtime/go-signal.c +++ b/libgo/runtime/go-signal.c @@ -245,3 +245,26 @@ runtime_resetcpuprofiler(int32 hz) m->profilehz = hz; } + +/* Used by the os package to raise SIGPIPE. */ + +void os_sigpipe (void) __asm__ ("libgo_os.os.sigpipe"); + +void +os_sigpipe (void) +{ + struct sigaction sa; + int i; + + memset (&sa, 0, sizeof sa); + + sa.sa_handler = SIG_DFL; + + i = sigemptyset (&sa.sa_mask); + __go_assert (i == 0); + + if (sigaction (SIGPIPE, &sa, NULL) != 0) + abort (); + + raise (SIGPIPE); +} diff --git a/libgo/runtime/go-type-eface.c b/libgo/runtime/go-type-eface.c index 84ca05ee1dc..cea9c240a66 100644 --- a/libgo/runtime/go-type-eface.c +++ b/libgo/runtime/go-type-eface.c @@ -5,6 +5,7 @@ license that can be found in the LICENSE file. */ #include "interface.h" +#include "go-panic.h" #include "go-type.h" /* A hash function for an empty interface. */ @@ -43,6 +44,9 @@ __go_type_equal_empty_interface (const void *vv1, const void *vv2, v2 = (const struct __go_empty_interface *) vv2; v1_descriptor = v1->__type_descriptor; v2_descriptor = v2->__type_descriptor; + if (((uintptr_t) v1_descriptor & reflectFlags) != 0 + || ((uintptr_t) v2_descriptor & reflectFlags) != 0) + __go_panic_msg ("invalid interface value"); if (v1_descriptor == NULL || v2_descriptor == NULL) return v1_descriptor == v2_descriptor; if (!__go_type_descriptors_equal (v1_descriptor, v2_descriptor)) diff --git a/libgo/runtime/go-type.h b/libgo/runtime/go-type.h index 76681217fdf..e048141e93e 100644 --- a/libgo/runtime/go-type.h +++ b/libgo/runtime/go-type.h @@ -149,6 +149,9 @@ struct __go_array_type /* The element type. */ struct __go_type_descriptor *__element_type; + /* The type of a slice of the same element type. */ + struct __go_type_descriptor *__slice_type; + /* The length of the array. */ uintptr_t __len; }; @@ -289,6 +292,15 @@ struct __go_struct_type struct __go_open_array __fields; }; +/* If an empty interface has these bits set in its type pointer, it + was copied from a reflect.Value and is not a valid empty + interface. */ + +enum +{ + reflectFlags = 3, +}; + /* Whether a type descriptor is a pointer. */ static inline _Bool diff --git a/libgo/runtime/go-unreflect.c b/libgo/runtime/go-unreflect.c index 88604854876..c4da5fdcf05 100644 --- a/libgo/runtime/go-unreflect.c +++ b/libgo/runtime/go-unreflect.c @@ -5,6 +5,7 @@ license that can be found in the LICENSE file. */ #include "go-alloc.h" +#include "go-panic.h" #include "go-type.h" #include "interface.h" @@ -19,6 +20,9 @@ Unreflect (struct __go_empty_interface type, void *object) { struct __go_empty_interface ret; + if (((uintptr_t) type.__type_descriptor & reflectFlags) != 0) + __go_panic_msg ("invalid interface value"); + /* FIXME: We should check __type_descriptor to verify that this is really a type descriptor. */ ret.__type_descriptor = type.__object; diff --git a/libgo/runtime/go-unsafe-new.c b/libgo/runtime/go-unsafe-new.c index e55d415becb..b0a65c42263 100644 --- a/libgo/runtime/go-unsafe-new.c +++ b/libgo/runtime/go-unsafe-new.c @@ -5,6 +5,7 @@ license that can be found in the LICENSE file. */ #include "go-alloc.h" +#include "go-panic.h" #include "go-type.h" #include "interface.h" @@ -20,6 +21,9 @@ New (struct __go_empty_interface type) { const struct __go_type_descriptor *descriptor; + if (((uintptr_t) type.__type_descriptor & reflectFlags) != 0) + __go_panic_msg ("invalid interface value"); + /* FIXME: We should check __type_descriptor to verify that this is really a type descriptor. */ descriptor = (const struct __go_type_descriptor *) type.__object; diff --git a/libgo/runtime/go-unsafe-newarray.c b/libgo/runtime/go-unsafe-newarray.c index 3bea2829f78..5fd81ce2733 100644 --- a/libgo/runtime/go-unsafe-newarray.c +++ b/libgo/runtime/go-unsafe-newarray.c @@ -5,6 +5,7 @@ license that can be found in the LICENSE file. */ #include "go-alloc.h" +#include "go-panic.h" #include "go-type.h" #include "interface.h" @@ -21,6 +22,9 @@ NewArray (struct __go_empty_interface type, int n) { const struct __go_type_descriptor *descriptor; + if (((uintptr_t) type.__type_descriptor & reflectFlags) != 0) + __go_panic_msg ("invalid interface value"); + /* FIXME: We should check __type_descriptor to verify that this is really a type descriptor. */ descriptor = (const struct __go_type_descriptor *) type.__object; diff --git a/libgo/runtime/iface.goc b/libgo/runtime/iface.goc index 356b318cbc8..05e37736b88 100644 --- a/libgo/runtime/iface.goc +++ b/libgo/runtime/iface.goc @@ -3,6 +3,7 @@ // license that can be found in the LICENSE file. package runtime +#include "go-panic.h" #include "go-type.h" #include "interface.h" #define nil NULL @@ -33,6 +34,8 @@ func ifacetype(i interface) (d *const_descriptor) { // Convert an empty interface to an empty interface. func ifaceE2E2(e empty_interface) (ret empty_interface, ok bool) { + if(((uintptr_t)e.__type_descriptor&reflectFlags) != 0) + __go_panic_msg("invalid interface value"); ret = e; ok = ret.__type_descriptor != nil; } @@ -52,6 +55,8 @@ func ifaceI2E2(i interface) (ret empty_interface, ok bool) { // Convert an empty interface to a non-empty interface. func ifaceE2I2(inter *descriptor, e empty_interface) (ret interface, ok bool) { + if(((uintptr_t)e.__type_descriptor&reflectFlags) != 0) + __go_panic_msg("invalid interface value"); if (e.__type_descriptor == nil) { ret.__methods = nil; ret.__object = nil; @@ -81,6 +86,8 @@ func ifaceI2I2(inter *descriptor, i interface) (ret interface, ok bool) { // Convert an empty interface to a pointer type. func ifaceE2T2P(inter *descriptor, e empty_interface) (ret *void, ok bool) { + if(((uintptr_t)e.__type_descriptor&reflectFlags) != 0) + __go_panic_msg("invalid interface value"); if (!__go_type_descriptors_equal(inter, e.__type_descriptor)) { ret = nil; ok = 0; @@ -104,6 +111,8 @@ func ifaceI2T2P(inter *descriptor, i interface) (ret *void, ok bool) { // Convert an empty interface to a non-pointer type. func ifaceE2T2(inter *descriptor, e empty_interface, ret *void) (ok bool) { + if(((uintptr_t)e.__type_descriptor&reflectFlags) != 0) + __go_panic_msg("invalid interface value"); if (!__go_type_descriptors_equal(inter, e.__type_descriptor)) { __builtin_memset(ret, 0, inter->__size); ok = 0; diff --git a/libgo/runtime/malloc.goc b/libgo/runtime/malloc.goc index 196271abd04..b46995ae0b6 100644 --- a/libgo/runtime/malloc.goc +++ b/libgo/runtime/malloc.goc @@ -386,7 +386,7 @@ runtime_MHeap_SysAlloc(MHeap *h, uintptr n) return nil; if(p < h->arena_start || (uintptr)(p+n - h->arena_start) >= MaxArena32) { - runtime_printf("runtime: memory allocated by OS not in usable range"); + runtime_printf("runtime: memory allocated by OS not in usable range\n"); runtime_SysFree(p, n); return nil; } diff --git a/libgo/runtime/mcache.c b/libgo/runtime/mcache.c index ce65757587d..65d849c16ea 100644 --- a/libgo/runtime/mcache.c +++ b/libgo/runtime/mcache.c @@ -22,6 +22,8 @@ runtime_MCache_Alloc(MCache *c, int32 sizeclass, uintptr size, int32 zeroed) // Replenish using central lists. n = runtime_MCentral_AllocList(&runtime_mheap.central[sizeclass], runtime_class_to_transfercount[sizeclass], &first); + if(n == 0) + runtime_throw("out of memory"); l->list = first; l->nlist = n; c->size += n*size; diff --git a/libgo/runtime/mgc0.c b/libgo/runtime/mgc0.c index 27fc3cdcc4d..0f28f5f6bd8 100644 --- a/libgo/runtime/mgc0.c +++ b/libgo/runtime/mgc0.c @@ -90,6 +90,11 @@ scanblock(byte *b, int64 n) void **bw, **w, **ew; Workbuf *wbuf; + if((int64)(uintptr)n != n || n < 0) { + // runtime_printf("scanblock %p %lld\n", b, (long long)n); + runtime_throw("scanblock"); + } + // Memory arena parameters. arena_start = runtime_mheap.arena_start; @@ -602,7 +607,7 @@ runfinq(void* dummy) next = f->next; params[0] = &f->arg; - reflect_call(f->ft, (void*)f->fn, 0, params, nil); + reflect_call(f->ft, (void*)f->fn, 0, 0, params, nil); f->fn = nil; f->arg = nil; f->next = nil; diff --git a/libgo/runtime/mheap.c b/libgo/runtime/mheap.c index b36df258818..cc6b3aff423 100644 --- a/libgo/runtime/mheap.c +++ b/libgo/runtime/mheap.c @@ -181,9 +181,7 @@ MHeap_Grow(MHeap *h, uintptr npage) // Allocate a multiple of 64kB (16 pages). npage = (npage+15)&~15; ask = npage< (uintptr)(h->arena_end - h->arena_used)) - return false; - if(ask < HeapAllocChunk && HeapAllocChunk <= h->arena_end - h->arena_used) + if(ask < HeapAllocChunk) ask = HeapAllocChunk; v = runtime_MHeap_SysAlloc(h, ask); @@ -192,8 +190,10 @@ MHeap_Grow(MHeap *h, uintptr npage) ask = npage<__methods.__count == 0) { - // already an empty interface - *(Eface*)ret = *(Eface*)x; - return; + if(((uintptr)e.__type_descriptor&reflectFlags) != 0) + runtime_throw("invalid interface value"); + t = e.__type_descriptor; + if(t == nil) { + // explicit conversions require non-nil interface value. + newTypeAssertionError(nil, nil, inter, + nil, nil, inter->__reflection, + nil, &err); + __go_panic(err); } - xt = ((Eface*)x)->__type_descriptor; - if(xt == nil) { - // can assign nil to any interface - ((Iface*)ret)->__methods = nil; - ((Iface*)ret)->__object = nil; - return; - } - ((Iface*)ret)->__methods = __go_convert_interface(&t->__common, xt); - ((Iface*)ret)->__object = ((Eface*)x)->__object; + ret->__object = e.__object; + ret->__methods = __go_convert_interface(inter, t); } diff --git a/libgo/runtime/runtime.h b/libgo/runtime/runtime.h index 00ba40f6f28..bf07c1219ba 100644 --- a/libgo/runtime/runtime.h +++ b/libgo/runtime/runtime.h @@ -194,8 +194,8 @@ void runtime_resetcpuprofiler(int32); void runtime_setcpuprofilerate(void(*)(uintptr*, int32), int32); struct __go_func_type; -void reflect_call(const struct __go_func_type *, const void *, _Bool, void **, - void **) +void reflect_call(const struct __go_func_type *, const void *, _Bool, _Bool, + void **, void **) asm ("libgo_reflect.reflect.call"); #ifdef __rtems__ diff --git a/libgo/syscalls/stringbyte.go b/libgo/syscalls/stringbyte.go index b673c9b02bb..17619536f2e 100644 --- a/libgo/syscalls/stringbyte.go +++ b/libgo/syscalls/stringbyte.go @@ -6,6 +6,8 @@ package syscall +import "unsafe" + // StringByteSlice returns a NUL-terminated slice of bytes // containing the text of s. func StringByteSlice(s string) []byte { @@ -22,3 +24,14 @@ func StringBytePtr(s string) *byte { p := StringByteSlice(s); return &p[0]; } + +// BytePtrToString takes a NUL-terminated array of bytes and convert +// it to a Go string. +func BytePtrToString(p *byte) string { + a := (*[10000]byte)(unsafe.Pointer(p)) + i := 0 + for a[i] != 0 { + i++ + } + return string(a[:i]) +} diff --git a/libgo/syscalls/syscall_unix.go b/libgo/syscalls/syscall_unix.go index b5a660e93d1..e633ea19146 100644 --- a/libgo/syscalls/syscall_unix.go +++ b/libgo/syscalls/syscall_unix.go @@ -56,3 +56,12 @@ func munmap(addr uintptr, length uintptr) (errno int) { } return } + +func libc_getrusage(who int, rusage *Rusage) int __asm__ ("getrusage") + +func Getrusage(who int, rusage *Rusage) (errno int) { + if libc_getrusage(who, rusage) < 0 { + errno = GetErrno() + } + return +}