-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsolver.go
126 lines (107 loc) · 2.8 KB
/
solver.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
package shapes
import (
"fmt"
"reflect"
"github.com/pkg/errors"
)
// solver.go implements the constraint solver.
// There are two kinds of constraints to solve: variable constraints and SubjectTo constraints.
// exprConstraint records the requirement that expression a must be equal to
// expression b. solve discharges such constraints by unifying a with b.
type exprConstraint struct {
	a Expr // left-hand side of the equality
	b Expr // right-hand side of the equality
}
// apply returns a new exprConstraint whose two sides have had the
// substitutions ss applied. The receiver is not modified.
func (c exprConstraint) apply(ss substitutions) substitutable {
	lhs := c.a.apply(ss).(Expr)
	rhs := c.b.apply(ss).(Expr)
	return exprConstraint{a: lhs, b: rhs}
}
// freevars returns the free variables of both sides of the constraint,
// delegating to exprtup, which shares exprConstraint's underlying struct type.
func (c exprConstraint) freevars() varset {
	pair := exprtup(c)
	return pair.freevars()
}
// Format implements fmt.Formatter, rendering the constraint as "{a = b}"
// regardless of the verb r.
func (c exprConstraint) Format(f fmt.State, r rune) {
	fmt.Fprintf(f, "{%v = %v}", c.a, c.b)
}
// constraints is an ordered list of equality constraints to be solved.
type constraints []exprConstraint

// apply applies ss to every constraint and returns the receiver.
//
// NOTE: elements are rewritten in place, so the backing array shared with
// the caller (and with any slice it was cut from) is modified.
func (cs constraints) apply(ss substitutions) substitutable {
	if len(ss) > 0 {
		for i, c := range cs {
			cs[i] = c.apply(ss).(exprConstraint)
		}
	}
	return cs
}

// freevars returns the de-duplicated free variables of all constraints.
func (cs constraints) freevars() varset {
	var all varset
	for _, c := range cs {
		all = append(all, c.freevars()...)
	}
	return unique(all)
}
// solve discharges the constraints in cs front to back, starting from the
// substitutions already known in subs, and returns the combined result.
//
// For each constraint {a = b}, a and b are unified. The resulting
// substitutions are composed with those found so far via
// compose(ss, newSubs), and the composite is applied to the constraints that
// remain. constraints.apply rewrites elements in place, so the caller's
// backing array for cs is modified.
//
// An empty cs returns subs unchanged. If a constraint side is not a
// substitutableExpr, or any unification fails, solve returns nil and the
// error.
func solve(cs constraints, subs substitutions) (newSubs substitutions, err error) {
	newSubs = subs
	// Iterative form of "solve the head, then solve the rewritten tail".
	for len(cs) > 0 {
		c := cs[0]

		// Checked assertions: report an error rather than panicking inside
		// library code when an Expr does not support unification.
		a, aOK := c.a.(substitutableExpr)
		b, bOK := c.b.(substitutableExpr)
		if !aOK || !bOK {
			return nil, errors.Errorf("Cannot solve %v: both sides must be substitutable expressions", c)
		}

		var ss substitutions
		if ss, err = unify(a, b); err != nil {
			return nil, err
		}
		newSubs = compose(ss, newSubs)
		cs = cs[1:].apply(newSubs).(constraints)
	}
	return newSubs, nil
}
// unify computes the substitutions that make a and b equal.
//
// Rules, applied in order:
//   - if a and b are already equal (eq), no substitution is needed;
//   - if a is a Var, it is bound to b; otherwise, if b is a Var, it is bound to a;
//   - otherwise both sides must have sub-expressions, the same number of
//     them, and those are unified pairwise by unifyMany.
//
// An error is returned when two unequal leaf expressions meet, when the
// sub-expression counts differ, or when bind rejects a binding.
func unify(a, b substitutableExpr) (ss substitutions, err error) {
	// Check equality before binding. Unifying a Var with itself must succeed
	// trivially; passing it to bind would hit the occurs check (a Var's free
	// variables presumably include itself) and report a spurious
	// "Recursive unification", e.g. for {a = a} or for (a, a) ~ (b, b) once
	// unifyMany has substituted a := b.
	if eq(a, b) {
		return nil, nil
	}
	if v, ok := a.(Var); ok {
		return bind(v, b)
	}
	if v, ok := b.(Var); ok {
		return bind(v, a)
	}

	aExprs := a.subExprs()
	bExprs := b.subExprs()
	switch {
	case len(aExprs) == 0 && len(bExprs) == 0:
		// Two distinct leaves with nothing to recurse into.
		return nil, errors.Errorf("Unification Fail. %v ~ %v cannot proceed", a, b)
	case len(aExprs) != len(bExprs):
		return nil, errors.Errorf("Unification Fail. %v ~ %v cannot proceed as they do not contain the same amount of sub-expressions. %v has %d subexpressions while %v has %d subexpressions", a, b, a, len(aExprs), b, len(bExprs))
	}
	return unifyMany(aExprs, bExprs)
}
// unifyMany unifies as[i] with bs[i] for every i, left to right, and returns
// the accumulated substitutions.
//
// Before each pair is unified, the substitutions gathered so far are applied
// to both sides so that earlier bindings constrain later pairs. It returns an
// error if the slices differ in length or if any pair fails to unify. Empty
// inputs yield (nil, nil).
func unifyMany(as, bs []substitutableExpr) (ss substitutions, err error) {
	// Guard against indexing past the end of bs; the caller in unify checks
	// this too, but unifyMany should not panic on its own.
	if len(as) != len(bs) {
		return nil, errors.Errorf("Unification Fail. Cannot unify %d sub-expressions with %d sub-expressions", len(as), len(bs))
	}
	for i := range as {
		a, b := as[i], bs[i]
		if len(ss) > 0 {
			a = a.apply(ss).(substitutableExpr)
			b = b.apply(ss).(substitutableExpr)
		}
		var s2 substitutions
		if s2, err = unify(a, b); err != nil {
			return nil, err
		}
		if ss == nil {
			ss = s2
		} else {
			// NOTE(review): solve composes newest-first (compose(ss, subs)),
			// whereas this composes oldest-first. Confirm against compose's
			// definition that the result is idempotent.
			ss = compose(ss, s2)
		}
	}
	return ss, nil
}
// eq reports whether a and b are deeply equal according to reflect.DeepEqual.
// It is marked as a temporary solution: a dedicated structural equality on
// expressions may replace it.
func eq(a, b interface{}) bool {
	same := reflect.DeepEqual(a, b)
	return same
}
// bind produces the single substitution that replaces v with E.
//
// If v is among E's free variables the binding would be self-referential,
// so an error is returned instead. E must also be an Expr; the type
// assertion panics otherwise.
func bind(v Var, E substitutable) (substitutions, error) {
	if !occurs(v, E) {
		return substitutions{{Sub: E.(Expr), For: v}}, nil
	}
	return nil, errors.Errorf("Recursive unification")
}
// occurs reports whether v appears among the free variables of in.
// bind uses it to reject self-referential bindings.
func occurs(v Var, in substitutable) bool {
	return in.freevars().Contains(v)
}