From 1088ff74cb906a71514554c4d4493749a7884e1a Mon Sep 17 00:00:00 2001 From: "Sergei G." Date: Mon, 3 Nov 2025 19:35:22 +0400 Subject: [PATCH 1/5] io: add Err field to LimitedReader Add an Err field to LimitedReader that allows callers to return a custom error when the read limit is exceeded, instead of always returning EOF. When Err is set to a non-nil, non-EOF value, and the limit is reached, LimitedReader.Read probes the underlying reader with a 1-byte read to distinguish two cases: stream had exactly N bytes (returns EOF), or stream has more data (returns the custom Err). The probe result is cached using negative N values to avoid repeated reads. When Err is nil or EOF, Read returns EOF, maintaining backward compatibility. Zero-length reads return (0, nil) without side effects. Fixes #51115 --- api/next/51115.txt | 1 + src/io/example_test.go | 24 +++++ src/io/io.go | 66 ++++++++++--- src/io/io_test.go | 212 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 290 insertions(+), 13 deletions(-) create mode 100644 api/next/51115.txt diff --git a/api/next/51115.txt b/api/next/51115.txt new file mode 100644 index 00000000000000..0ce24b4ed04e71 --- /dev/null +++ b/api/next/51115.txt @@ -0,0 +1 @@ +pkg io, type LimitedReader struct, Err error #51115 diff --git a/src/io/example_test.go b/src/io/example_test.go index 818020e9dec6dd..f24debcac6adee 100644 --- a/src/io/example_test.go +++ b/src/io/example_test.go @@ -5,6 +5,7 @@ package io_test import ( + "errors" "fmt" "io" "log" @@ -121,6 +122,29 @@ func ExampleLimitReader() { // some } +func ExampleLimitedReader_Err() { + r := strings.NewReader("some io.Reader stream to be read\n") + sentinel := errors.New("read limit reached") + lr := &io.LimitedReader{R: r, N: 4, Err: sentinel} + + buf := make([]byte, 10) + n, err := lr.Read(buf) + if err != nil { + log.Fatal(err) + } + fmt.Printf("read %d bytes: %q\n", n, buf[:n]) + + // try to read more and get the custom error + n, err = lr.Read(buf) + if errors.Is(err, sentinel) { + fmt.Println("error:", err) + } + + // Output: + // read 4 bytes: "some" + // error: read limit reached +} + func ExampleMultiReader() { r1 := strings.NewReader("first reader ") r2 := strings.NewReader("second reader ") diff --git a/src/io/io.go b/src/io/io.go index 00edcde763a55a..54aac8c846e77e 100644 --- a/src/io/io.go +++ b/src/io/io.go @@ -457,28 +457,66 @@ func copyBuffer(dst Writer, src Reader, buf []byte) (written int64, err error) { // LimitReader returns a Reader that reads from r // but stops with EOF after n bytes. -// The underlying implementation is a *LimitedReader. -func LimitReader(r Reader, n int64) Reader { return &LimitedReader{r, n} } +// To return a custom error when the limit is reached, construct +// a *LimitedReader directly with the desired Err field. +func LimitReader(r Reader, n int64) Reader { return &LimitedReader{R: r, N: n} } // A LimitedReader reads from R but limits the amount of // data returned to just N bytes. Each call to Read // updates N to reflect the new amount remaining. -// Read returns EOF when N <= 0 or when the underlying R returns EOF. +// +// Negative values of N mean that the limit has been exceeded. +// Read returns Err when more than N bytes are read from R. +// If Err is nil or EOF, Read returns EOF instead. type LimitedReader struct { - R Reader // underlying reader - N int64 // max bytes remaining + R Reader // underlying reader + N int64 // max bytes remaining + Err error // error to return when limit is exceeded; defaults to EOF if nil } func (l *LimitedReader) Read(p []byte) (n int, err error) { - if l.N <= 0 { + if len(p) == 0 { + return 0, nil + } + // We use negative l.N values to signal that we've exceeded the limit and cached the result: + // -1 means more data is available + // -2 means hit EOF exactly + + if l.N > 0 { + if int64(len(p)) > l.N { + p = p[0:l.N] + } + n, err = l.R.Read(p) + l.N -= int64(n) + return + } + + if l.N < 0 { + if l.N == -1 && l.Err != nil && l.Err != EOF { + return 0, l.Err // limit was exceeded + } + return 0, EOF // stream was exactly N bytes, or already past limit + } + + // At limit (N == 0) - need to determine if stream has more data + + if l.Err == nil || l.Err == EOF { return 0, EOF } - if int64(len(p)) > l.N { - p = p[0:l.N] + + // Probe with one byte to distinguish two cases: + // - Stream had exactly N bytes -> return EOF + // - Stream has more than N bytes -> return custom error + // We can't tell without reading ahead. This probe permanently consumes + // a byte from R, so we cache the result in N to avoid re-probing. + var probe [1]byte + probeN, probeErr := l.R.Read(probe[:]) + if probeN > 0 || (probeErr != nil && probeErr != EOF) { + l.N = -1 // more data available, limit exceeded + return 0, l.Err } - n, err = l.R.Read(p) - l.N -= int64(n) - return + l.N = -2 // hit EOF, stream was exactly N bytes + return 0, EOF } // NewSectionReader returns a [SectionReader] that reads from r @@ -518,8 +556,10 @@ func (s *SectionReader) Read(p []byte) (n int, err error) { return } -var errWhence = errors.New("Seek: invalid whence") -var errOffset = errors.New("Seek: invalid offset") +var ( + errWhence = errors.New("Seek: invalid whence") + errOffset = errors.New("Seek: invalid offset") +) func (s *SectionReader) Seek(offset int64, whence int) (int64, error) { switch whence { diff --git a/src/io/io_test.go b/src/io/io_test.go index 38bec8243e7baa..8d517f6fc08119 100644 --- a/src/io/io_test.go +++ b/src/io/io_test.go @@ -692,3 +692,215 @@ func TestOffsetWriter_Write(t *testing.T) { checkContent(name, f) }) } + +var errLimit = errors.New("limit exceeded") + +func TestLimitedReader(t *testing.T) { + src := strings.NewReader("abc") + r := LimitReader(src, 5) + lr, ok := r.(*LimitedReader) + if !ok { + t.Fatalf("LimitReader should return *LimitedReader, got %T", r) + } + if lr.R != src || lr.N != 5 || lr.Err != nil { + t.Fatalf("LimitReader() = {R: %v, N: %d, Err: %v}, want {R: %v, N: 5, Err: nil}", lr.R, lr.N, lr.Err, src) + } + + t.Run("WithoutCustomErr", func(t *testing.T) { + tests := []struct { + name string + data string + limit int64 + want1N int + want1E error + want2E error + }{ + {"UnderLimit", "hello", 10, 5, nil, EOF}, + {"ExactLimit", "hello", 5, 5, nil, EOF}, + {"OverLimit", "hello world", 5, 5, nil, EOF}, + {"ZeroLimit", "hello", 0, 0, EOF, EOF}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lr := &LimitedReader{R: strings.NewReader(tt.data), N: tt.limit} + buf := make([]byte, 10) + + n, err := lr.Read(buf) + if n != tt.want1N || err != tt.want1E { + t.Errorf("first Read() = (%d, %v), want (%d, %v)", n, err, tt.want1N, tt.want1E) + } + + n, err = lr.Read(buf) + if n != 0 || err != tt.want2E { + t.Errorf("second Read() = (%d, %v), want (0, %v)", n, err, tt.want2E) + } + }) + } + }) + + t.Run("WithCustomErr", func(t *testing.T) { + tests := []struct { + name string + data string + limit int64 + err error + wantFirst string + wantErr1 error + wantErr2 error + }{ + {"ExactLimit", "hello", 5, errLimit, "hello", nil, EOF}, + {"OverLimit", "hello world", 5, errLimit, "hello", nil, errLimit}, + {"UnderLimit", "hi", 5, errLimit, "hi", nil, EOF}, + {"ZeroLimitEmpty", "", 0, errLimit, "", EOF, EOF}, + {"ZeroLimitNonEmpty", "hello", 0, errLimit, "", errLimit, errLimit}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lr := &LimitedReader{R: strings.NewReader(tt.data), N: tt.limit, Err: tt.err} + buf := make([]byte, 10) + + n, err := lr.Read(buf) + if n != len(tt.wantFirst) || string(buf[:n]) != tt.wantFirst || err != tt.wantErr1 { + t.Errorf("first Read() = (%d, %q, %v), want (%d, %q, %v)", n, buf[:n], err, len(tt.wantFirst), tt.wantFirst, tt.wantErr1) + } + + n, err = lr.Read(buf) + if n != 0 || err != tt.wantErr2 { + t.Errorf("second Read() = (%d, %v), want (0, %v)", n, err, tt.wantErr2) + } + }) + } + }) + + t.Run("CustomErrPersists", func(t *testing.T) { + lr := &LimitedReader{R: strings.NewReader("hello world"), N: 5, Err: errLimit} + buf := make([]byte, 10) + + n, err := lr.Read(buf) + if n != 5 || err != nil || string(buf[:5]) != "hello" { + t.Errorf("Read() = (%d, %v, %q), want (5, nil, \"hello\")", n, err, buf[:5]) + } + + n, err = lr.Read(buf) + if n != 0 || err != errLimit { + t.Errorf("Read() = (%d, %v), want (0, errLimit)", n, err) + } + + n, err = lr.Read(buf) + if n != 0 || err != errLimit { + t.Errorf("Read() = (%d, %v), want (0, errLimit)", n, err) + } + }) + + t.Run("ErrEOF", func(t *testing.T) { + lr := &LimitedReader{R: strings.NewReader("hello world"), N: 5, Err: EOF} + buf := make([]byte, 10) + + n, err := lr.Read(buf) + if n != 5 || err != nil { + t.Errorf("Read() = (%d, %v), want (5, nil)", n, err) + } + + n, err = lr.Read(buf) + if n != 0 || err != EOF { + t.Errorf("Read() = (%d, %v), want (0, EOF)", n, err) + } + }) + + t.Run("NoSideEffects", func(t *testing.T) { + lr := &LimitedReader{R: strings.NewReader("hello"), N: 5, Err: errLimit} + buf := make([]byte, 0) + + for i := 0; i < 3; i++ { + n, err := lr.Read(buf) + if n != 0 || err != nil { + t.Errorf("zero-length read #%d = (%d, %v), want (0, nil)", i+1, n, err) + } + if lr.N != 5 { + t.Errorf("N after zero-length read #%d = %d, want 5", i+1, lr.N) + } + } + + buf = make([]byte, 10) + n, err := lr.Read(buf) + if n != 5 || string(buf[:5]) != "hello" || err != nil { + t.Errorf("normal Read() = (%d, %q, %v), want (5, \"hello\", nil)", n, buf[:5], err) + } + }) +} + +type errorReader struct { + data []byte + pos int + err error +} + +func (r *errorReader) Read(p []byte) (int, error) { + if r.pos >= len(r.data) { + return 0, r.err + } + n := copy(p, r.data[r.pos:]) + r.pos += n + return n, nil +} + +func TestLimitedReaderErrors(t *testing.T) { + t.Run("UnderlyingError", func(t *testing.T) { + underlyingErr := errors.New("boom") + lr := &LimitedReader{R: &errorReader{data: []byte("hello"), err: underlyingErr}, N: 10} + buf := make([]byte, 10) + + n, err := lr.Read(buf) + if n != 5 || string(buf[:5]) != "hello" || err != nil { + t.Errorf("first Read() = (%d, %q, %v), want (5, \"hello\", nil)", n, buf[:5], err) + } + + n, err = lr.Read(buf) + if n != 0 || err != underlyingErr { + t.Errorf("second Read() = (%d, %v), want (0, %v)", n, err, underlyingErr) + } + }) + + t.Run("SentinelMasksProbeError", func(t *testing.T) { + probeErr := errors.New("probe failed") + lr := &LimitedReader{R: &errorReader{data: []byte("hello"), err: probeErr}, N: 5, Err: errLimit} + buf := make([]byte, 10) + + n, err := lr.Read(buf) + if n != 5 || string(buf[:5]) != "hello" || err != nil { + t.Errorf("first Read() = (%d, %q, %v), want (5, \"hello\", nil)", n, buf[:5], err) + } + + n, err = lr.Read(buf) + if n != 0 || err != errLimit { + t.Errorf("second Read() = (%d, %v), want (0, errLimit)", n, err) + } + }) +} + +func TestLimitedReaderCopy(t *testing.T) { + tests := []struct { + name string + input string + limit int64 + wantN int64 + wantErr error + }{ + {"Exact", "hello", 5, 5, nil}, + {"Under", "hi", 5, 2, nil}, + {"Over", "hello world", 5, 5, errLimit}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lr := &LimitedReader{R: strings.NewReader(tt.input), N: tt.limit, Err: errLimit} + var dst Buffer + n, err := Copy(&dst, lr) + if n != tt.wantN || err != tt.wantErr { + t.Errorf("Copy() = (%d, %v), want (%d, %v)", n, err, tt.wantN, tt.wantErr) + } + }) + } +} From ad4322601d12637b114cee4bdd07e6555b9b2291 Mon Sep 17 00:00:00 2001 From: "Sergei G." Date: Tue, 4 Nov 2025 10:28:31 +0400 Subject: [PATCH 2/5] io: add doc/next file and remove errors.Is from example --- doc/next/6-stdlib/99-minor/io/51115.md | 4 ++++ src/io/example_test.go | 10 ++++------ src/io/io.go | 6 ++---- 3 files changed, 10 insertions(+), 10 deletions(-) create mode 100644 doc/next/6-stdlib/99-minor/io/51115.md diff --git a/doc/next/6-stdlib/99-minor/io/51115.md b/doc/next/6-stdlib/99-minor/io/51115.md new file mode 100644 index 00000000000000..ec71f23e9ed3a6 --- /dev/null +++ b/doc/next/6-stdlib/99-minor/io/51115.md @@ -0,0 +1,4 @@ + +The new [LimitedReader.Err] field allows returning a custom error +when the read limit is exceeded. When nil (the default), [LimitedReader.Read] +returns [EOF] when the limit is reached, maintaining backward compatibility. diff --git a/src/io/example_test.go b/src/io/example_test.go index f24debcac6adee..5fca1d0b4e3f4f 100644 --- a/src/io/example_test.go +++ b/src/io/example_test.go @@ -132,17 +132,15 @@ func ExampleLimitedReader_Err() { if err != nil { log.Fatal(err) } - fmt.Printf("read %d bytes: %q\n", n, buf[:n]) + fmt.Printf("%d; %q\n", n, buf[:n]) // try to read more and get the custom error n, err = lr.Read(buf) - if errors.Is(err, sentinel) { - fmt.Println("error:", err) - } + fmt.Printf("%d; error: %v\n", n, err) // Output: - // read 4 bytes: "some" - // error: read limit reached + // 4; "some" + // 0; error: read limit reached } func ExampleMultiReader() { diff --git a/src/io/io.go b/src/io/io.go index 54aac8c846e77e..8a5e794cd68f01 100644 --- a/src/io/io.go +++ b/src/io/io.go @@ -556,10 +556,8 @@ func (s *SectionReader) Read(p []byte) (n int, err error) { return } -var ( - errWhence = errors.New("Seek: invalid whence") - errOffset = errors.New("Seek: invalid offset") -) +var errWhence = errors.New("Seek: invalid whence") +var errOffset = errors.New("Seek: invalid offset") func (s *SectionReader) Seek(offset int64, whence int) (int64, error) { switch whence { From 938c0989f855be506d28c85a003a9ca0587cf68a Mon Sep 17 00:00:00 2001 From: "Sergei G." Date: Tue, 4 Nov 2025 11:34:20 +0400 Subject: [PATCH 3/5] io: remove unnecessary EOF checks --- src/io/io.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/io/io.go b/src/io/io.go index 8a5e794cd68f01..59f9a265ad688e 100644 --- a/src/io/io.go +++ b/src/io/io.go @@ -467,7 +467,7 @@ func LimitReader(r Reader, n int64) Reader { return &LimitedReader{R: r, N: n} } // // Negative values of N mean that the limit has been exceeded. // Read returns Err when more than N bytes are read from R. -// If Err is nil or EOF, Read returns EOF instead. +// If Err is nil, Read returns EOF. type LimitedReader struct { R Reader // underlying reader N int64 // max bytes remaining @@ -492,7 +492,7 @@ func (l *LimitedReader) Read(p []byte) (n int, err error) { } if l.N < 0 { - if l.N == -1 && l.Err != nil && l.Err != EOF { + if l.N == -1 && l.Err != nil { return 0, l.Err // limit was exceeded } return 0, EOF // stream was exactly N bytes, or already past limit @@ -500,7 +500,7 @@ func (l *LimitedReader) Read(p []byte) (n int, err error) { // At limit (N == 0) - need to determine if stream has more data - if l.Err == nil || l.Err == EOF { + if l.Err == nil { return 0, EOF } @@ -511,7 +511,7 @@ func (l *LimitedReader) Read(p []byte) (n int, err error) { // a byte from R, so we cache the result in N to avoid re-probing. var probe [1]byte probeN, probeErr := l.R.Read(probe[:]) - if probeN > 0 || (probeErr != nil && probeErr != EOF) { + if probeN > 0 || probeErr != EOF { l.N = -1 // more data available, limit exceeded return 0, l.Err } From caab4051e42e230e3b4f9bbe4d1305263644bdda Mon Sep 17 00:00:00 2001 From: "Sergei G." Date: Tue, 4 Nov 2025 12:14:25 +0400 Subject: [PATCH 4/5] io: use consts in LimitedReader and clean up comments --- src/io/io.go | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/io/io.go b/src/io/io.go index 59f9a265ad688e..e02888a7d4f178 100644 --- a/src/io/io.go +++ b/src/io/io.go @@ -471,16 +471,17 @@ func LimitReader(r Reader, n int64) Reader { return &LimitedReader{R: r, N: n} } type LimitedReader struct { R Reader // underlying reader N int64 // max bytes remaining - Err error // error to return when limit is exceeded; defaults to EOF if nil + Err error // error to return when limit is exceeded } func (l *LimitedReader) Read(p []byte) (n int, err error) { - if len(p) == 0 { + // We use negative l.N values to signal that we've exceeded the limit and cached the result. + const sentinelExceeded = -1 // Probed and found more data available + const sentinelExactMatch = -2 // Probed and found EOF (exactly N bytes) + + if len(p) == 0 && l.N <= 0 { return 0, nil } - // We use negative l.N values to signal that we've exceeded the limit and cached the result: - // -1 means more data is available - // -2 means hit EOF exactly if l.N > 0 { if int64(len(p)) > l.N { @@ -491,11 +492,13 @@ func (l *LimitedReader) Read(p []byte) (n int, err error) { return } + // Sentinel if l.N < 0 { - if l.N == -1 && l.Err != nil { - return 0, l.Err // limit was exceeded + if l.N == sentinelExceeded && l.Err != nil { + return 0, l.Err } - return 0, EOF // stream was exactly N bytes, or already past limit + + return 0, EOF } // At limit (N == 0) - need to determine if stream has more data @@ -504,18 +507,17 @@ func (l *LimitedReader) Read(p []byte) (n int, err error) { return 0, EOF } - // Probe with one byte to distinguish two cases: - // - Stream had exactly N bytes -> return EOF - // - Stream has more than N bytes -> return custom error + // Probe with one byte to distinguish two sentinels. // We can't tell without reading ahead. This probe permanently consumes // a byte from R, so we cache the result in N to avoid re-probing. var probe [1]byte probeN, probeErr := l.R.Read(probe[:]) if probeN > 0 || probeErr != EOF { - l.N = -1 // more data available, limit exceeded + l.N = sentinelExceeded return 0, l.Err } - l.N = -2 // hit EOF, stream was exactly N bytes + + l.N = sentinelExactMatch return 0, EOF } From 79886e9ee85ab68449d67f66f9ac3718725f4b4f Mon Sep 17 00:00:00 2001 From: "Sergei G." Date: Wed, 5 Nov 2025 08:28:12 +0400 Subject: [PATCH 5/5] io: use raw string literals in tests --- src/io/io_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/io/io_test.go b/src/io/io_test.go index 8d517f6fc08119..53c16ede9c9e83 100644 --- a/src/io/io_test.go +++ b/src/io/io_test.go @@ -780,7 +780,7 @@ func TestLimitedReader(t *testing.T) { n, err := lr.Read(buf) if n != 5 || err != nil || string(buf[:5]) != "hello" { - t.Errorf("Read() = (%d, %v, %q), want (5, nil, \"hello\")", n, err, buf[:5]) + t.Errorf(`Read() = (%d, %v, %q), want (5, nil, "hello")`, n, err, buf[:5]) } n, err = lr.Read(buf) @@ -826,7 +826,7 @@ func TestLimitedReader(t *testing.T) { buf = make([]byte, 10) n, err := lr.Read(buf) if n != 5 || string(buf[:5]) != "hello" || err != nil { - t.Errorf("normal Read() = (%d, %q, %v), want (5, \"hello\", nil)", n, buf[:5], err) + t.Errorf(`normal Read() = (%d, %q, %v), want (5, "hello", nil)`, n, buf[:5], err) } }) } @@ -854,7 +854,7 @@ func TestLimitedReaderErrors(t *testing.T) { n, err := lr.Read(buf) if n != 5 || string(buf[:5]) != "hello" || err != nil { - t.Errorf("first Read() = (%d, %q, %v), want (5, \"hello\", nil)", n, buf[:5], err) + t.Errorf(`first Read() = (%d, %q, %v), want (5, "hello", nil)`, n, buf[:5], err) } n, err = lr.Read(buf) @@ -870,7 +870,7 @@ func TestLimitedReaderErrors(t *testing.T) { n, err := lr.Read(buf) if n != 5 || string(buf[:5]) != "hello" || err != nil { - t.Errorf("first Read() = (%d, %q, %v), want (5, \"hello\", nil)", n, buf[:5], err) + t.Errorf(`first Read() = (%d, %q, %v), want (5, "hello", nil)`, n, buf[:5], err) } n, err = lr.Read(buf)