複数のテーブルからのいくつかの操作を組み合わせたトランザクションを実行する必要があります。今回は、dbトランザクション実装をシンプルに導入する方法について紹介します。この記事では、sqlcというライブラリを使います。golangでよく使われるDBアクセスライブラリについては、こちらの記事を参照下さい。

トランザクションとは

複数の操作で構成される単一の作業単位のことをトランザクションと言います。

DBトランザクションとは

データベーストランザクションについてwikipediaで紹介されているものを紹介します。

トランザクション処理では、データベースの個々の操作が自動的に1つに連結され、不可分のトランザクションとされることがある。トランザクション処理システムは、1つのトランザクション内の全操作がエラー無しに成功するか、全操作が実行されないことを保証する。一部の操作が成功し、他の操作でエラーが発生した場合、トランザクション処理システムはそのトランザクションの「全」操作を「ロールバック; roll back」し、そのトランザクションによる痕跡を消去してデータベースを一貫した状態(そのトランザクションを開始する前の状態)にリストアする。あるトランザクションの全操作が完了した場合、そのトランザクションはシステムによって「コミット; commit」され、データベースに加えられた更新内容が恒久的なものとなる。コミットされたトランザクションがロールバックされることはない。

出所:Wikipedia

簡単に書くと、複数のデータベース操作で構成される単一の作業単位です。

トランザクションの例

「トランザクション」の例と言うと、やはり銀行の口座振込みの例がいいと思います。ここでは、Aさんの口座(残高:15,000円)から5,000円、Bさんの口座(残高:20,000円)へ振り込むケースを考えます。こちらの操作は、2つの処理から成り立っています。

  1. Aさんの口座から5,000円分差し引き、残高15,000 – 5,000 = 10,000円にする。
  2. Bさんの口座へ5,000円分プラスし、残高20,000 + 5,000 = 25,000円にする。

それぞれが別々の処理の場合、1.の処理が成功して、2.で失敗したというケースを考えると、Aさんの口座から引き出された5,000円がAさんの口座からなくなり、行方不明の状態になります。2.で処理失敗した場合には、Aさんの口座に対して行った操作も取り消さないといけません(この処理は、ロールバック)。このように処理単位として整合性をとる必要があるものをトランザクションといいます。この例からもトランザクションが必要であることがわかると思います。

DBトランザクションを実装する

DBトランザクションを実装します。今回は、sqlcを例に汎用的なトランザクション処理を実装する方法について紹介します。sqlcを使って、sqlから自動生成(sqlc generate)すると下記のようなコードが自動生成されます。DBTxインタフェースの作り形がポイントになってます。

 
// Code generated by sqlc. DO NOT EDIT.

package db

import (
	"context"
	"database/sql"
)

type DBTX interface {
	ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
	PrepareContext(context.Context, string) (*sql.Stmt, error)
	QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
	QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}

func New(db DBTX) *Queries {
	return &Queries{db: db}
}

type Queries struct {
	db DBTX
}

func (q *Queries) WithTx(tx *sql.Tx) *Queries {
	return &Queries{
		db: tx,
	}
}

goでトランザクションを開始する場合には、BeginTx()を用います。また、こちらのBeginTx()の戻り値は、Tx構造体が返却されますが、こちらは、先程のDBTxが提供するinterfaceを実装しています。execTx()というトランザクション処理するメソッドを作成して、そのcallback関数として、処理をもらうようにすることでトランザクション処理が散らばることを防ぎます。

 
package db

import (
	"context"
	"database/sql"
	"fmt"
)

type Store struct {
	*Queries
	db *sql.DB
}

func NewStore(db *sql.DB) *Store {
	return &Store{
		db:      db,
		Queries: New(db),
	}
}

func (store *Store) execTx(ctx context.Context, fn func(*Queries) error) error {
	// read-committedでトランザクションを開始する
	tx, err := store.db.BeginTx(ctx, &sql.TxOptions{})
	if err != nil {
		return err
	}

	q := New(tx)
	err = fn(q)
	if err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr)
		}
		return err
	}

	return tx.Commit()
}

sqlcで自動生成されるコードには、sql操作から下記のようなインタフェースが実装されます。このインタフェース自体もDBトランザクション処理の中から利用することが可能です。

 
// Code generated by sqlc. DO NOT EDIT.

package db

import (
	"context"
	"database/sql"
)

type Querier interface {
	CreateOrder(ctx context.Context, userID int32) (sql.NullInt32, error)
	CreateOrderItem(ctx context.Context, arg CreateOrderItemParams) (OrdersItem, error)
	GetOrder(ctx context.Context, userID int32) ([]GetOrderRow, error)
}

var _ Querier = (*Queries)(nil)

実際に使うコードは、下記のようにします。execTx()関数自体は、このパッケージのみに公開するようにして、トランザクションを含む関数をxxxTx() のような形でパッケージ外部に公開してあげます。

 
type CreateOrderTxParams struct {
	UserID   int32    `json:"user_id"`
	Products []string `json:"products"`
}

type CreateOrderTxResult struct {
	OrderID int32 `json:"order_id"`
}

func (store *Store) AddOrderTx(ctx context.Context, arg CreateOrderTxParams) (CreateOrderTxResult, error) {
	var result CreateOrderTxResult

	err := store.execTx(ctx, func(q *Queries) error {
		dbInt, err := q.CreateOrder(ctx, arg.UserID)
		if err != nil {
			return fmt.Errorf("error create order %w", err)
		}

		if !dbInt.Valid {
			return fmt.Errorf("error create cart item %w", err)
		}

		// orderIDを取得
		orderID := dbInt.Int32

		// order詳細を生成する
		for _, productID := range arg.Products {
			params := CreateOrderItemParams{
				OrderID:   orderID,
				ProductID: productID,
			}

			_, err := q.CreateOrderItem(ctx, params)
			if err != nil {
				return fmt.Errorf("error create order item %w", err)
			}
		}
		// 戻り値としてorderIDをセットする
		result.OrderID = orderID
		return nil
	})

	return result, err
}

今回は、sqlcで生成したDBアクセス用コードにDBトランザクションを追加する方法について紹介しました。この方法ですと自動コード部分に手をいれずにDBトランザクションを使うコードをうまくカプセル化でき、他パッケージにDBトランザクションコードが散らばることがありません。DBTxインタフェースに依存するようにコードを書けば、他のDBアクセスライブラリでも同様のことができると思います。お疲れ様でした。