119 lines
2.8 KiB
Go
119 lines
2.8 KiB
Go
|
package squirrel
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"errors"
|
||
|
|
||
|
"github.com/lann/builder"
|
||
|
)
|
||
|
|
||
|
func init() {
|
||
|
builder.Register(CaseBuilder{}, caseData{})
|
||
|
}
|
||
|
|
||
|
// sqlizerBuffer is a helper that allows to write many Sqlizers one by one
|
||
|
// without constant checks for errors that may come from Sqlizer
|
||
|
type sqlizerBuffer struct {
|
||
|
bytes.Buffer
|
||
|
args []interface{}
|
||
|
err error
|
||
|
}
|
||
|
|
||
|
// WriteSql converts Sqlizer to SQL strings and writes it to buffer
|
||
|
func (b *sqlizerBuffer) WriteSql(item Sqlizer) {
|
||
|
if b.err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
var str string
|
||
|
var args []interface{}
|
||
|
str, args, b.err = item.ToSql()
|
||
|
|
||
|
if b.err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
b.WriteString(str)
|
||
|
b.WriteByte(' ')
|
||
|
b.args = append(b.args, args...)
|
||
|
}
|
||
|
|
||
|
func (b *sqlizerBuffer) ToSql() (string, []interface{}, error) {
|
||
|
return b.String(), b.args, b.err
|
||
|
}
|
||
|
|
||
|
// whenPart is a helper structure to describe SQLs "WHEN ... THEN ..." expression
|
||
|
type whenPart struct {
|
||
|
when Sqlizer
|
||
|
then Sqlizer
|
||
|
}
|
||
|
|
||
|
func newWhenPart(when interface{}, then interface{}) whenPart {
|
||
|
return whenPart{newPart(when), newPart(then)}
|
||
|
}
|
||
|
|
||
|
// caseData holds all the data required to build a CASE SQL construct
|
||
|
type caseData struct {
|
||
|
What Sqlizer
|
||
|
WhenParts []whenPart
|
||
|
Else Sqlizer
|
||
|
}
|
||
|
|
||
|
// ToSql implements Sqlizer
|
||
|
func (d *caseData) ToSql() (sqlStr string, args []interface{}, err error) {
|
||
|
if len(d.WhenParts) == 0 {
|
||
|
err = errors.New("case expression must contain at lease one WHEN clause")
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
sql := sqlizerBuffer{}
|
||
|
|
||
|
sql.WriteString("CASE ")
|
||
|
if d.What != nil {
|
||
|
sql.WriteSql(d.What)
|
||
|
}
|
||
|
|
||
|
for _, p := range d.WhenParts {
|
||
|
sql.WriteString("WHEN ")
|
||
|
sql.WriteSql(p.when)
|
||
|
sql.WriteString("THEN ")
|
||
|
sql.WriteSql(p.then)
|
||
|
}
|
||
|
|
||
|
if d.Else != nil {
|
||
|
sql.WriteString("ELSE ")
|
||
|
sql.WriteSql(d.Else)
|
||
|
}
|
||
|
|
||
|
sql.WriteString("END")
|
||
|
|
||
|
return sql.ToSql()
|
||
|
}
|
||
|
|
||
|
// CaseBuilder builds SQL CASE construct which could be used as parts of queries.
|
||
|
type CaseBuilder builder.Builder
|
||
|
|
||
|
// ToSql builds the query into a SQL string and bound args.
|
||
|
func (b CaseBuilder) ToSql() (string, []interface{}, error) {
|
||
|
data := builder.GetStruct(b).(caseData)
|
||
|
return data.ToSql()
|
||
|
}
|
||
|
|
||
|
// what sets optional value for CASE construct "CASE [value] ..."
|
||
|
func (b CaseBuilder) what(expr interface{}) CaseBuilder {
|
||
|
return builder.Set(b, "What", newPart(expr)).(CaseBuilder)
|
||
|
}
|
||
|
|
||
|
// When adds "WHEN ... THEN ..." part to CASE construct
|
||
|
func (b CaseBuilder) When(when interface{}, then interface{}) CaseBuilder {
|
||
|
// TODO: performance hint: replace slice of WhenPart with just slice of parts
|
||
|
// where even indices of the slice belong to "when"s and odd indices belong to "then"s
|
||
|
return builder.Append(b, "WhenParts", newWhenPart(when, then)).(CaseBuilder)
|
||
|
}
|
||
|
|
||
|
// What sets optional "ELSE ..." part for CASE construct
|
||
|
func (b CaseBuilder) Else(expr interface{}) CaseBuilder {
|
||
|
return builder.Set(b, "Else", newPart(expr)).(CaseBuilder)
|
||
|
}
|