diff --git a/go.mod b/go.mod index a08572c..7d1e982 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.25.0 require ( github.com/Masterminds/squirrel v1.5.4 github.com/georgysavva/scany/v2 v2.1.4 + github.com/jackc/pgerrcode v0.0.0-20250907135507-afb5586c32a6 github.com/jackc/pgx/v5 v5.9.0 github.com/stretchr/testify v1.11.1 ) diff --git a/go.sum b/go.sum index b791cab..2df5734 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/georgysavva/scany/v2 v2.1.4 h1:nrzHEJ4oQVRoiKmocRqA1IyGOmM/GQOEsg9UjM github.com/georgysavva/scany/v2 v2.1.4/go.mod h1:fqp9yHZzM/PFVa3/rYEC57VmDx+KDch0LoqrJzkvtos= github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= +github.com/jackc/pgerrcode v0.0.0-20250907135507-afb5586c32a6 h1:D/V0gu4zQ3cL2WKeVNVM4r2gLxGGf6McLwgXzRTo2RQ= +github.com/jackc/pgerrcode v0.0.0-20250907135507-afb5586c32a6/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= diff --git a/pgerr/pgerr.go b/pgerr/pgerr.go new file mode 100644 index 0000000..d05b0f3 --- /dev/null +++ b/pgerr/pgerr.go @@ -0,0 +1,48 @@ +// Package pgerr provides PostgreSQL error inspection helpers built on +// pgconn.PgError. SQLSTATE code constants are re-exported from +// github.com/jackc/pgerrcode — import that package directly if you need +// the full table of codes or class-membership helpers. +package pgerr + +import ( + "database/sql" + "errors" + + "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +// IsErrorNoRows reports whether err is or wraps pgx.ErrNoRows or sql.ErrNoRows. +func IsErrorNoRows(err error) bool { + return errors.Is(err, pgx.ErrNoRows) || errors.Is(err, sql.ErrNoRows) +} + +// IsUniqueViolation reports whether err is a Postgres unique constraint +// violation (SQLSTATE 23505). When true, the returned string is the +// constraint name reported by the server (empty if the driver did not +// surface it). +func IsUniqueViolation(err error) (string, bool) { + var pgErr *pgconn.PgError + if !errors.As(err, &pgErr) || pgErr.Code != pgerrcode.UniqueViolation { + return "", false + } + return pgErr.ConstraintName, true +} + +// IsFatal reports whether err is a Postgres error that should be treated as +// fatal for retry/alerting purposes. SQLSTATE classes 00 (success), 01 +// (warning), 02 (no data), and 23 (integrity constraint violation) are +// non-fatal; every other class is fatal. Non-PgError errors return false +// (no opinion). +func IsFatal(err error) bool { + var pgErr *pgconn.PgError + if !errors.As(err, &pgErr) || len(pgErr.Code) < 2 { + return false + } + switch pgErr.Code[:2] { + case "00", "01", "02", "23": + return false + } + return true +} diff --git a/pgerr/pgerr_test.go b/pgerr/pgerr_test.go new file mode 100644 index 0000000..c3299b9 --- /dev/null +++ b/pgerr/pgerr_test.go @@ -0,0 +1,107 @@ +package pgerr_test + +import ( + "database/sql" + "errors" + "fmt" + "testing" + + "github.com/goware/pgkit/v2/pgerr" + "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/assert" +) + +func TestIsErrorNoRows(t *testing.T) { + t.Run("pgx.ErrNoRows", func(t *testing.T) { + assert.True(t, pgerr.IsErrorNoRows(pgx.ErrNoRows)) + }) + t.Run("sql.ErrNoRows", func(t *testing.T) { + assert.True(t, pgerr.IsErrorNoRows(sql.ErrNoRows)) + }) + t.Run("wrapped pgx.ErrNoRows", func(t *testing.T) { + assert.True(t, pgerr.IsErrorNoRows(fmt.Errorf("scan failed: %w", pgx.ErrNoRows))) + }) + t.Run("nil", func(t *testing.T) { + assert.False(t, pgerr.IsErrorNoRows(nil)) + }) + t.Run("unrelated error", func(t *testing.T) { + assert.False(t, pgerr.IsErrorNoRows(errors.New("boom"))) + }) +} + +func TestIsUniqueViolation(t *testing.T) { + t.Run("unique violation with constraint name", func(t *testing.T) { + err := &pgconn.PgError{Code: pgerrcode.UniqueViolation, ConstraintName: "uq_users_email"} + name, ok := pgerr.IsUniqueViolation(err) + assert.True(t, ok) + assert.Equal(t, "uq_users_email", name) + }) + t.Run("wrapped unique violation", func(t *testing.T) { + pgErr := &pgconn.PgError{Code: pgerrcode.UniqueViolation, ConstraintName: "uq_users_email"} + wrapped := fmt.Errorf("insert failed: %w", pgErr) + name, ok := pgerr.IsUniqueViolation(wrapped) + assert.True(t, ok) + assert.Equal(t, "uq_users_email", name) + }) + t.Run("unique violation without constraint name", func(t *testing.T) { + err := &pgconn.PgError{Code: pgerrcode.UniqueViolation} + name, ok := pgerr.IsUniqueViolation(err) + assert.True(t, ok) + assert.Equal(t, "", name) + }) + t.Run("different pg error code", func(t *testing.T) { + err := &pgconn.PgError{Code: pgerrcode.ForeignKeyViolation, ConstraintName: "fk_x"} + name, ok := pgerr.IsUniqueViolation(err) + assert.False(t, ok) + assert.Equal(t, "", name) + }) + t.Run("non pg error", func(t *testing.T) { + name, ok := pgerr.IsUniqueViolation(errors.New("boom")) + assert.False(t, ok) + assert.Equal(t, "", name) + }) + t.Run("nil", func(t *testing.T) { + name, ok := pgerr.IsUniqueViolation(nil) + assert.False(t, ok) + assert.Equal(t, "", name) + }) +} + +func TestIsFatal(t *testing.T) { + t.Run("class 00 successful completion is non-fatal", func(t *testing.T) { + assert.False(t, pgerr.IsFatal(&pgconn.PgError{Code: pgerrcode.SuccessfulCompletion})) + }) + t.Run("class 01 warning is non-fatal", func(t *testing.T) { + assert.False(t, pgerr.IsFatal(&pgconn.PgError{Code: pgerrcode.Warning})) + }) + t.Run("class 02 no data is non-fatal", func(t *testing.T) { + assert.False(t, pgerr.IsFatal(&pgconn.PgError{Code: pgerrcode.NoData})) + }) + t.Run("class 23 unique violation is non-fatal", func(t *testing.T) { + assert.False(t, pgerr.IsFatal(&pgconn.PgError{Code: pgerrcode.UniqueViolation})) + }) + t.Run("class 08 connection exception is fatal", func(t *testing.T) { + assert.True(t, pgerr.IsFatal(&pgconn.PgError{Code: pgerrcode.ConnectionException})) + }) + t.Run("class 57 operator intervention is fatal", func(t *testing.T) { + assert.True(t, pgerr.IsFatal(&pgconn.PgError{Code: pgerrcode.QueryCanceled})) + }) + t.Run("class XX internal error is fatal", func(t *testing.T) { + assert.True(t, pgerr.IsFatal(&pgconn.PgError{Code: pgerrcode.InternalError})) + }) + t.Run("wrapped fatal pg error", func(t *testing.T) { + pgErr := &pgconn.PgError{Code: pgerrcode.ConnectionFailure} + assert.True(t, pgerr.IsFatal(fmt.Errorf("query: %w", pgErr))) + }) + t.Run("non pg error", func(t *testing.T) { + assert.False(t, pgerr.IsFatal(errors.New("boom"))) + }) + t.Run("nil", func(t *testing.T) { + assert.False(t, pgerr.IsFatal(nil)) + }) + t.Run("short code is non-fatal", func(t *testing.T) { + assert.False(t, pgerr.IsFatal(&pgconn.PgError{Code: "X"})) + }) +} diff --git a/pgkit.go b/pgkit.go index 1100ba3..841d0d5 100644 --- a/pgkit.go +++ b/pgkit.go @@ -34,6 +34,17 @@ func (d *DB) TxQueryFromContext(ctx context.Context) *Querier { return d.TxQuery(tx) } +// InTx returns a new *DB that shares Conn and SQL with d but routes queries +// through tx. Use it when a function takes *DB and you want it to participate +// in a transaction the caller controls. +func (d *DB) InTx(tx pgx.Tx) *DB { + return &DB{ + Conn: d.Conn, + SQL: d.SQL, + Query: d.TxQuery(tx), + } +} + type Config struct { Database string `toml:"database"` Host string `toml:"host"` diff --git a/pgkit_test.go b/pgkit_test.go new file mode 100644 index 0000000..7aadb50 --- /dev/null +++ b/pgkit_test.go @@ -0,0 +1,43 @@ +package pgkit_test + +import ( + "testing" + + "github.com/goware/pgkit/v2" + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// stubTx satisfies pgx.Tx via interface embedding; method calls would panic, +// but InTx only stores the reference, so the stub is enough to verify wiring. +type stubTx struct{ pgx.Tx } + +func TestDBInTx(t *testing.T) { + t.Run("wires tx into a fresh Querier", func(t *testing.T) { + db := &pgkit.DB{ + SQL: &pgkit.StatementBuilder{}, + Query: &pgkit.Querier{}, + } + tx := &stubTx{} + + inTx := db.InTx(tx) + + require.NotNil(t, inTx) + assert.Same(t, db.SQL, inTx.SQL, "SQL should be shared with parent") + assert.NotSame(t, db.Query, inTx.Query, "Query should be a fresh tx-scoped Querier") + assert.Equal(t, pgx.Tx(tx), inTx.Query.Tx, "Querier.Tx should hold the input tx") + }) + t.Run("parent Query is untouched", func(t *testing.T) { + parentQuery := &pgkit.Querier{} + db := &pgkit.DB{ + SQL: &pgkit.StatementBuilder{}, + Query: parentQuery, + } + + _ = db.InTx(&stubTx{}) + + assert.Same(t, parentQuery, db.Query, "parent DB.Query reference must not change") + assert.Nil(t, parentQuery.Tx, "parent Querier.Tx must stay nil") + }) +}