summaryrefslogtreecommitdiff
path: root/src/pkg/text/template/funcs.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/pkg/text/template/funcs.go')
-rw-r--r--src/pkg/text/template/funcs.go206
1 files changed, 185 insertions, 21 deletions
diff --git a/src/pkg/text/template/funcs.go b/src/pkg/text/template/funcs.go
index 818766364..e85412262 100644
--- a/src/pkg/text/template/funcs.go
+++ b/src/pkg/text/template/funcs.go
@@ -6,6 +6,7 @@ package template
import (
"bytes"
+ "errors"
"fmt"
"io"
"net/url"
@@ -35,6 +36,14 @@ var builtins = FuncMap{
"printf": fmt.Sprintf,
"println": fmt.Sprintln,
"urlquery": URLQueryEscaper,
+
+ // Comparisons
+ "eq": eq, // ==
+ "ge": ge, // >=
+ "gt": gt, // >
+ "le": le, // <=
+ "lt": lt, // <
+ "ne": ne, // !=
}
var builtinFuncs = createValueFuncs(builtins)
@@ -199,7 +208,7 @@ func call(fn interface{}, args ...interface{}) (interface{}, error) {
argv[i] = value
}
result := v.Call(argv)
- if len(result) == 2 {
+ if len(result) == 2 && !result[1].IsNil() {
return result[0].Interface(), result[1].Interface().(error)
}
return result[0].Interface(), nil
@@ -248,6 +257,160 @@ func not(arg interface{}) (truth bool) {
return !truth
}
+// Comparison.
+
+// TODO: Perhaps allow comparison between signed and unsigned integers.
+
+var (
+ errBadComparisonType = errors.New("invalid type for comparison")
+ errBadComparison = errors.New("incompatible types for comparison")
+ errNoComparison = errors.New("missing argument for comparison")
+)
+
+type kind int
+
+const (
+ invalidKind kind = iota
+ boolKind
+ complexKind
+ intKind
+ floatKind
+ integerKind
+ stringKind
+ uintKind
+)
+
+func basicKind(v reflect.Value) (kind, error) {
+ switch v.Kind() {
+ case reflect.Bool:
+ return boolKind, nil
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return intKind, nil
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ return uintKind, nil
+ case reflect.Float32, reflect.Float64:
+ return floatKind, nil
+ case reflect.Complex64, reflect.Complex128:
+ return complexKind, nil
+ case reflect.String:
+ return stringKind, nil
+ }
+ return invalidKind, errBadComparisonType
+}
+
+// eq evaluates the comparison a == b || a == c || ...
+func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) {
+ v1 := reflect.ValueOf(arg1)
+ k1, err := basicKind(v1)
+ if err != nil {
+ return false, err
+ }
+ if len(arg2) == 0 {
+ return false, errNoComparison
+ }
+ for _, arg := range arg2 {
+ v2 := reflect.ValueOf(arg)
+ k2, err := basicKind(v2)
+ if err != nil {
+ return false, err
+ }
+ if k1 != k2 {
+ return false, errBadComparison
+ }
+ truth := false
+ switch k1 {
+ case boolKind:
+ truth = v1.Bool() == v2.Bool()
+ case complexKind:
+ truth = v1.Complex() == v2.Complex()
+ case floatKind:
+ truth = v1.Float() == v2.Float()
+ case intKind:
+ truth = v1.Int() == v2.Int()
+ case stringKind:
+ truth = v1.String() == v2.String()
+ case uintKind:
+ truth = v1.Uint() == v2.Uint()
+ default:
+ panic("invalid kind")
+ }
+ if truth {
+ return true, nil
+ }
+ }
+ return false, nil
+}
+
+// ne evaluates the comparison a != b.
+func ne(arg1, arg2 interface{}) (bool, error) {
+ // != is the inverse of ==.
+ equal, err := eq(arg1, arg2)
+ return !equal, err
+}
+
+// lt evaluates the comparison a < b.
+func lt(arg1, arg2 interface{}) (bool, error) {
+ v1 := reflect.ValueOf(arg1)
+ k1, err := basicKind(v1)
+ if err != nil {
+ return false, err
+ }
+ v2 := reflect.ValueOf(arg2)
+ k2, err := basicKind(v2)
+ if err != nil {
+ return false, err
+ }
+ if k1 != k2 {
+ return false, errBadComparison
+ }
+ truth := false
+ switch k1 {
+ case boolKind, complexKind:
+ return false, errBadComparisonType
+ case floatKind:
+ truth = v1.Float() < v2.Float()
+ case intKind:
+ truth = v1.Int() < v2.Int()
+ case stringKind:
+ truth = v1.String() < v2.String()
+ case uintKind:
+ truth = v1.Uint() < v2.Uint()
+ default:
+ panic("invalid kind")
+ }
+ return truth, nil
+}
+
+// le evaluates the comparison <= b.
+func le(arg1, arg2 interface{}) (bool, error) {
+ // <= is < or ==.
+ lessThan, err := lt(arg1, arg2)
+ if lessThan || err != nil {
+ return lessThan, err
+ }
+ return eq(arg1, arg2)
+}
+
+// gt evaluates the comparison a > b.
+func gt(arg1, arg2 interface{}) (bool, error) {
+ // > is the inverse of <=.
+ lessOrEqual, err := le(arg1, arg2)
+ if err != nil {
+ return false, err
+ }
+ return !lessOrEqual, nil
+}
+
+// ge evaluates the comparison a >= b.
+func ge(arg1, arg2 interface{}) (bool, error) {
+ // >= is the inverse of <.
+ lessThan, err := lt(arg1, arg2)
+ if err != nil {
+ return false, err
+ }
+ return !lessThan, nil
+}
+
// HTML escaping.
var (
@@ -298,15 +461,7 @@ func HTMLEscapeString(s string) string {
// HTMLEscaper returns the escaped HTML equivalent of the textual
// representation of its arguments.
func HTMLEscaper(args ...interface{}) string {
- ok := false
- var s string
- if len(args) == 1 {
- s, ok = args[0].(string)
- }
- if !ok {
- s = fmt.Sprint(args...)
- }
- return HTMLEscapeString(s)
+ return HTMLEscapeString(evalArgs(args))
}
// JavaScript escaping.
@@ -391,26 +546,35 @@ func jsIsSpecial(r rune) bool {
// JSEscaper returns the escaped JavaScript equivalent of the textual
// representation of its arguments.
func JSEscaper(args ...interface{}) string {
- ok := false
- var s string
- if len(args) == 1 {
- s, ok = args[0].(string)
- }
- if !ok {
- s = fmt.Sprint(args...)
- }
- return JSEscapeString(s)
+ return JSEscapeString(evalArgs(args))
}
// URLQueryEscaper returns the escaped value of the textual representation of
// its arguments in a form suitable for embedding in a URL query.
func URLQueryEscaper(args ...interface{}) string {
- s, ok := "", false
+ return url.QueryEscape(evalArgs(args))
+}
+
+// evalArgs formats the list of arguments into a string. It is therefore equivalent to
+// fmt.Sprint(args...)
+// except that each argument is indirected (if a pointer), as required,
+// using the same rules as the default string evaluation during template
+// execution.
+func evalArgs(args []interface{}) string {
+ ok := false
+ var s string
+ // Fast path for simple common case.
if len(args) == 1 {
s, ok = args[0].(string)
}
if !ok {
+ for i, arg := range args {
+ a, ok := printableValue(reflect.ValueOf(arg))
+ if ok {
+ args[i] = a
+ } // else left fmt do its thing
+ }
s = fmt.Sprint(args...)
}
- return url.QueryEscape(s)
+ return s
}