summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/cmd/gc/swt.c215
1 files changed, 212 insertions, 3 deletions
diff --git a/src/cmd/gc/swt.c b/src/cmd/gc/swt.c
index fb3e4b095..5f014e9a9 100644
--- a/src/cmd/gc/swt.c
+++ b/src/cmd/gc/swt.c
@@ -11,6 +11,7 @@ enum
Sfalse,
Stype,
};
+Node* binarysw(Node *t, Iter *save, Node *name);
/*
* walktype
@@ -263,6 +264,7 @@ prepsw(Node *sw, int arg)
Iter save;
Node *name, *bool, *cas;
Node *t, *a;
+//dump("prepsw before", sw->nbody->left);
cas = N;
name = N;
@@ -279,7 +281,7 @@ prepsw(Node *sw, int arg)
loop:
if(t == N) {
sw->nbody->left = rev(cas);
-//dump("case", sw->nbody->left);
+//dump("prepsw after", sw->nbody->left);
return;
}
@@ -309,21 +311,29 @@ loop:
goto loop;
}
- a = nod(OIF, N, N);
- a->nbody = t->right; // then goto l
switch(arg) {
default:
// not bool const
+ a = binarysw(t, &save, name);
+ if(a != N)
+ break;
+
+ a = nod(OIF, N, N);
a->ntest = nod(OEQ, name, t->left); // if name == val
+ a->nbody = t->right; // then goto l
break;
case Strue:
+ a = nod(OIF, N, N);
a->ntest = t->left; // if val
+ a->nbody = t->right; // then goto l
break;
case Sfalse:
+ a = nod(OIF, N, N);
a->ntest = nod(ONOT, t->left, N); // if !val
+ a->nbody = t->right; // then goto l
break;
}
cas = list(cas, a);
@@ -470,3 +480,202 @@ walkswitch(Node *sw)
walkstate(sw->nbody);
//print("normal done\n");
}
+
+/*
+ * binary search on cases
+ */
+enum
+{
+ Ncase = 4, // needed to binary search
+};
+
+typedef struct Case Case;
+struct Case
+{
+ Node* node; // points at case statement
+ Case* link; // linked list to link
+};
+#define C ((Case*)nil)
+
+int
+iscaseconst(Node *t)
+{
+ if(t == N || t->left == N)
+ return 0;
+ switch(whatis(t->left)) {
+ case Wlitfloat:
+ case Wlitint:
+ case Wlitstr:
+ return 1;
+ }
+ return 0;
+}
+
+int
+countcase(Node *t, Iter save)
+{
+ int n;
+
+ // note that the iter is by value,
+ // so cases are not really consumed
+ for(n=0;; n++) {
+ if(!iscaseconst(t))
+ return n;
+ t = listnext(&save);
+ }
+}
+
+Case*
+csort(Case *l, int(*f)(Case*, Case*))
+{
+ Case *l1, *l2, *le;
+
+ if(l == C || l->link == C)
+ return l;
+
+ l1 = l;
+ l2 = l;
+ for(;;) {
+ l2 = l2->link;
+ if(l2 == C)
+ break;
+ l2 = l2->link;
+ if(l2 == C)
+ break;
+ l1 = l1->link;
+ }
+
+ l2 = l1->link;
+ l1->link = C;
+ l1 = csort(l, f);
+ l2 = csort(l2, f);
+
+ /* set up lead element */
+ if((*f)(l1, l2) < 0) {
+ l = l1;
+ l1 = l1->link;
+ } else {
+ l = l2;
+ l2 = l2->link;
+ }
+ le = l;
+
+ for(;;) {
+ if(l1 == C) {
+ while(l2) {
+ le->link = l2;
+ le = l2;
+ l2 = l2->link;
+ }
+ le->link = C;
+ break;
+ }
+ if(l2 == C) {
+ while(l1) {
+ le->link = l1;
+ le = l1;
+ l1 = l1->link;
+ }
+ break;
+ }
+ if((*f)(l1, l2) < 0) {
+ le->link = l1;
+ le = l1;
+ l1 = l1->link;
+ } else {
+ le->link = l2;
+ le = l2;
+ l2 = l2->link;
+ }
+ }
+ le->link = C;
+ return l;
+}
+
+int
+casecmp(Case *c1, Case *c2)
+{
+ int w;
+
+ w = whatis(c1->node->left);
+ if(w != whatis(c2->node->left))
+ fatal("casecmp1");
+
+ switch(w) {
+ case Wlitfloat:
+ return mpcmpfltflt(c1->node->left->val.u.fval, c2->node->left->val.u.fval);
+ case Wlitint:
+ return mpcmpfixfix(c1->node->left->val.u.xval, c2->node->left->val.u.xval);
+ case Wlitstr:
+ return cmpslit(c1->node->left, c2->node->left);
+ }
+
+ fatal("casecmp2");
+ return 0;
+}
+
+Node*
+constsw(Case *c0, int ncase, Node *name)
+{
+ Node *cas, *a;
+ Case *c;
+ int i, n;
+
+ // small number do sequentially
+ if(ncase < Ncase) {
+ cas = N;
+ for(i=0; i<ncase; i++) {
+ a = nod(OIF, N, N);
+ a->ntest = nod(OEQ, name, c0->node->left);
+ a->nbody = c0->node->right;
+ cas = list(cas, a);
+ c0 = c0->link;
+ }
+ return rev(cas);
+ }
+
+ // find center and recur
+ c = c0;
+ n = ncase>>1;
+ for(i=0; i<n; i++)
+ c = c->link;
+
+ a = nod(OIF, N, N);
+ a->ntest = nod(OLE, name, c->node->left);
+ a->nbody = constsw(c0, n+1, name); // include center
+ a->nelse = constsw(c->link, ncase-n-1, name); // exclude center
+ return a;
+}
+
+Node*
+binarysw(Node *t, Iter *save, Node *name)
+{
+ Case *c, *c1;
+ int i, ncase;
+ Node *a;
+
+ ncase = countcase(t, *save);
+ if(ncase < Ncase)
+ return N;
+
+ c = C;
+ for(i=1; i<ncase; i++) {
+ c1 = mal(sizeof(*c1));
+ c1->link = c;
+ c1->node = t;
+ c = c1;
+
+ t = listnext(save);
+ }
+
+ // last one shouldnt consume the iter
+ c1 = mal(sizeof(*c1));
+ c1->link = c;
+ c1->node = t;
+ c = c1;
+
+ c = csort(c, casecmp);
+ a = constsw(c, ncase, name);
+//dump("bin", a);
+ return a;
+}