diff --git a/api/next/51115.txt b/api/next/51115.txt new file mode 100644 index 00000000000000..0ce24b4ed04e71 --- /dev/null +++ b/api/next/51115.txt @@ -0,0 +1 @@ +pkg io, type LimitedReader struct, Err error #51115 diff --git a/doc/next/6-stdlib/99-minor/io/51115.md b/doc/next/6-stdlib/99-minor/io/51115.md new file mode 100644 index 00000000000000..ec71f23e9ed3a6 --- /dev/null +++ b/doc/next/6-stdlib/99-minor/io/51115.md @@ -0,0 +1,4 @@ + +The new [LimitedReader.Err] field allows returning a custom error +when the read limit is exceeded. When nil (the default), [LimitedReader.Read] +returns [EOF] when the limit is reached, maintaining backward compatibility. diff --git a/src/io/example_test.go b/src/io/example_test.go index 818020e9dec6dd..5fca1d0b4e3f4f 100644 --- a/src/io/example_test.go +++ b/src/io/example_test.go @@ -5,6 +5,7 @@ package io_test import ( + "errors" "fmt" "io" "log" @@ -121,6 +122,27 @@ func ExampleLimitReader() { // some } +func ExampleLimitedReader_Err() { + r := strings.NewReader("some io.Reader stream to be read\n") + sentinel := errors.New("read limit reached") + lr := &io.LimitedReader{R: r, N: 4, Err: sentinel} + + buf := make([]byte, 10) + n, err := lr.Read(buf) + if err != nil { + log.Fatal(err) + } + fmt.Printf("%d; %q\n", n, buf[:n]) + + // try to read more and get the custom error + n, err = lr.Read(buf) + fmt.Printf("%d; error: %v\n", n, err) + + // Output: + // 4; "some" + // 0; error: read limit reached +} + func ExampleMultiReader() { r1 := strings.NewReader("first reader ") r2 := strings.NewReader("second reader ") diff --git a/src/io/io.go b/src/io/io.go index 00edcde763a55a..e02888a7d4f178 100644 --- a/src/io/io.go +++ b/src/io/io.go @@ -457,28 +457,68 @@ func copyBuffer(dst Writer, src Reader, buf []byte) (written int64, err error) { // LimitReader returns a Reader that reads from r // but stops with EOF after n bytes. -// The underlying implementation is a *LimitedReader. -func LimitReader(r Reader, n int64) Reader { return &LimitedReader{r, n} } +// To return a custom error when the limit is reached, construct +// a *LimitedReader directly with the desired Err field. +func LimitReader(r Reader, n int64) Reader { return &LimitedReader{R: r, N: n} } // A LimitedReader reads from R but limits the amount of // data returned to just N bytes. Each call to Read // updates N to reflect the new amount remaining. -// Read returns EOF when N <= 0 or when the underlying R returns EOF. +// +// Negative values of N mean that the limit has been exceeded. +// Read returns Err when more than N bytes are read from R. +// If Err is nil, Read returns EOF. type LimitedReader struct { - R Reader // underlying reader - N int64 // max bytes remaining + R Reader // underlying reader + N int64 // max bytes remaining + Err error // error to return when limit is exceeded } func (l *LimitedReader) Read(p []byte) (n int, err error) { - if l.N <= 0 { + // We use negative l.N values to signal that we've exceeded the limit and cached the result. + const sentinelExceeded = -1 // Probed and found more data available + const sentinelExactMatch = -2 // Probed and found EOF (exactly N bytes) + + if len(p) == 0 && l.N <= 0 { + return 0, nil + } + + if l.N > 0 { + if int64(len(p)) > l.N { + p = p[0:l.N] + } + n, err = l.R.Read(p) + l.N -= int64(n) + return + } + + // Sentinel + if l.N < 0 { + if l.N == sentinelExceeded && l.Err != nil { + return 0, l.Err + } + + return 0, EOF + } + + // At limit (N == 0) - need to determine if stream has more data + + if l.Err == nil { return 0, EOF } - if int64(len(p)) > l.N { - p = p[0:l.N] + + // Probe with one byte to distinguish two sentinels. + // We can't tell without reading ahead. This probe permanently consumes + // a byte from R, so we cache the result in N to avoid re-probing. + var probe [1]byte + probeN, probeErr := l.R.Read(probe[:]) + if probeN > 0 || probeErr != EOF { + l.N = sentinelExceeded + return 0, l.Err } - n, err = l.R.Read(p) - l.N -= int64(n) - return + + l.N = sentinelExactMatch + return 0, EOF } // NewSectionReader returns a [SectionReader] that reads from r diff --git a/src/io/io_test.go b/src/io/io_test.go index 38bec8243e7baa..53c16ede9c9e83 100644 --- a/src/io/io_test.go +++ b/src/io/io_test.go @@ -692,3 +692,215 @@ func TestOffsetWriter_Write(t *testing.T) { checkContent(name, f) }) } + +var errLimit = errors.New("limit exceeded") + +func TestLimitedReader(t *testing.T) { + src := strings.NewReader("abc") + r := LimitReader(src, 5) + lr, ok := r.(*LimitedReader) + if !ok { + t.Fatalf("LimitReader should return *LimitedReader, got %T", r) + } + if lr.R != src || lr.N != 5 || lr.Err != nil { + t.Fatalf("LimitReader() = {R: %v, N: %d, Err: %v}, want {R: %v, N: 5, Err: nil}", lr.R, lr.N, lr.Err, src) + } + + t.Run("WithoutCustomErr", func(t *testing.T) { + tests := []struct { + name string + data string + limit int64 + want1N int + want1E error + want2E error + }{ + {"UnderLimit", "hello", 10, 5, nil, EOF}, + {"ExactLimit", "hello", 5, 5, nil, EOF}, + {"OverLimit", "hello world", 5, 5, nil, EOF}, + {"ZeroLimit", "hello", 0, 0, EOF, EOF}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lr := &LimitedReader{R: strings.NewReader(tt.data), N: tt.limit} + buf := make([]byte, 10) + + n, err := lr.Read(buf) + if n != tt.want1N || err != tt.want1E { + t.Errorf("first Read() = (%d, %v), want (%d, %v)", n, err, tt.want1N, tt.want1E) + } + + n, err = lr.Read(buf) + if n != 0 || err != tt.want2E { + t.Errorf("second Read() = (%d, %v), want (0, %v)", n, err, tt.want2E) + } + }) + } + }) + + t.Run("WithCustomErr", func(t *testing.T) { + tests := []struct { + name string + data string + limit int64 + err error + wantFirst string + wantErr1 error + wantErr2 error + }{ + {"ExactLimit", "hello", 5, errLimit, "hello", nil, EOF}, + {"OverLimit", "hello world", 5, errLimit, "hello", nil, errLimit}, + {"UnderLimit", "hi", 5, errLimit, "hi", nil, EOF}, + {"ZeroLimitEmpty", "", 0, errLimit, "", EOF, EOF}, + {"ZeroLimitNonEmpty", "hello", 0, errLimit, "", errLimit, errLimit}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lr := &LimitedReader{R: strings.NewReader(tt.data), N: tt.limit, Err: tt.err} + buf := make([]byte, 10) + + n, err := lr.Read(buf) + if n != len(tt.wantFirst) || string(buf[:n]) != tt.wantFirst || err != tt.wantErr1 { + t.Errorf("first Read() = (%d, %q, %v), want (%d, %q, %v)", n, buf[:n], err, len(tt.wantFirst), tt.wantFirst, tt.wantErr1) + } + + n, err = lr.Read(buf) + if n != 0 || err != tt.wantErr2 { + t.Errorf("second Read() = (%d, %v), want (0, %v)", n, err, tt.wantErr2) + } + }) + } + }) + + t.Run("CustomErrPersists", func(t *testing.T) { + lr := &LimitedReader{R: strings.NewReader("hello world"), N: 5, Err: errLimit} + buf := make([]byte, 10) + + n, err := lr.Read(buf) + if n != 5 || err != nil || string(buf[:5]) != "hello" { + t.Errorf(`Read() = (%d, %v, %q), want (5, nil, "hello")`, n, err, buf[:5]) + } + + n, err = lr.Read(buf) + if n != 0 || err != errLimit { + t.Errorf("Read() = (%d, %v), want (0, errLimit)", n, err) + } + + n, err = lr.Read(buf) + if n != 0 || err != errLimit { + t.Errorf("Read() = (%d, %v), want (0, errLimit)", n, err) + } + }) + + t.Run("ErrEOF", func(t *testing.T) { + lr := &LimitedReader{R: strings.NewReader("hello world"), N: 5, Err: EOF} + buf := make([]byte, 10) + + n, err := lr.Read(buf) + if n != 5 || err != nil { + t.Errorf("Read() = (%d, %v), want (5, nil)", n, err) + } + + n, err = lr.Read(buf) + if n != 0 || err != EOF { + t.Errorf("Read() = (%d, %v), want (0, EOF)", n, err) + } + }) + + t.Run("NoSideEffects", func(t *testing.T) { + lr := &LimitedReader{R: strings.NewReader("hello"), N: 5, Err: errLimit} + buf := make([]byte, 0) + + for i := 0; i < 3; i++ { + n, err := lr.Read(buf) + if n != 0 || err != nil { + t.Errorf("zero-length read #%d = (%d, %v), want (0, nil)", i+1, n, err) + } + if lr.N != 5 { + t.Errorf("N after zero-length read #%d = %d, want 5", i+1, lr.N) + } + } + + buf = make([]byte, 10) + n, err := lr.Read(buf) + if n != 5 || string(buf[:5]) != "hello" || err != nil { + t.Errorf(`normal Read() = (%d, %q, %v), want (5, "hello", nil)`, n, buf[:5], err) + } + }) +} + +type errorReader struct { + data []byte + pos int + err error +} + +func (r *errorReader) Read(p []byte) (int, error) { + if r.pos >= len(r.data) { + return 0, r.err + } + n := copy(p, r.data[r.pos:]) + r.pos += n + return n, nil +} + +func TestLimitedReaderErrors(t *testing.T) { + t.Run("UnderlyingError", func(t *testing.T) { + underlyingErr := errors.New("boom") + lr := &LimitedReader{R: &errorReader{data: []byte("hello"), err: underlyingErr}, N: 10} + buf := make([]byte, 10) + + n, err := lr.Read(buf) + if n != 5 || string(buf[:5]) != "hello" || err != nil { + t.Errorf(`first Read() = (%d, %q, %v), want (5, "hello", nil)`, n, buf[:5], err) + } + + n, err = lr.Read(buf) + if n != 0 || err != underlyingErr { + t.Errorf("second Read() = (%d, %v), want (0, %v)", n, err, underlyingErr) + } + }) + + t.Run("SentinelMasksProbeError", func(t *testing.T) { + probeErr := errors.New("probe failed") + lr := &LimitedReader{R: &errorReader{data: []byte("hello"), err: probeErr}, N: 5, Err: errLimit} + buf := make([]byte, 10) + + n, err := lr.Read(buf) + if n != 5 || string(buf[:5]) != "hello" || err != nil { + t.Errorf(`first Read() = (%d, %q, %v), want (5, "hello", nil)`, n, buf[:5], err) + } + + n, err = lr.Read(buf) + if n != 0 || err != errLimit { + t.Errorf("second Read() = (%d, %v), want (0, errLimit)", n, err) + } + }) +} + +func TestLimitedReaderCopy(t *testing.T) { + tests := []struct { + name string + input string + limit int64 + wantN int64 + wantErr error + }{ + {"Exact", "hello", 5, 5, nil}, + {"Under", "hi", 5, 2, nil}, + {"Over", "hello world", 5, 5, errLimit}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lr := &LimitedReader{R: strings.NewReader(tt.input), N: tt.limit, Err: errLimit} + var dst Buffer + n, err := Copy(&dst, lr) + if n != tt.wantN || err != tt.wantErr { + t.Errorf("Copy() = (%d, %v), want (%d, %v)", n, err, tt.wantN, tt.wantErr) + } + }) + } +}