summaryrefslogtreecommitdiff
path: root/src/mir/from_hir_match.cpp
diff options
context:
space:
mode:
authorJohn Hodge <tpg@mutabah.net>2016-08-13 22:15:27 +0800
committerJohn Hodge <tpg@mutabah.net>2016-08-13 22:15:27 +0800
commit30567917626efb381e94ef719447db88d0d5685f (patch)
tree1440ed757c539a2456fed27e2be204343b3efcbf /src/mir/from_hir_match.cpp
parent4534f1e5acb1deaf1efccc0f91a8e0c09c163259 (diff)
downloadmrust-30567917626efb381e94ef719447db88d0d5685f.tar.gz
MIR Gen Match - Rewrite handling of DecisionTree for correctness
- Pattern entries record what field they were from - Generation skips if the current test is not for that field
Diffstat (limited to 'src/mir/from_hir_match.cpp')
-rw-r--r--src/mir/from_hir_match.cpp289
1 files changed, 198 insertions, 91 deletions
diff --git a/src/mir/from_hir_match.cpp b/src/mir/from_hir_match.cpp
index 8f1fd76a..5fda73a2 100644
--- a/src/mir/from_hir_match.cpp
+++ b/src/mir/from_hir_match.cpp
@@ -11,7 +11,9 @@
void MIR_LowerHIR_Match( MirBuilder& builder, MirConverter& conv, ::HIR::ExprNode_Match& node, ::MIR::LValue match_val );
-TAGGED_UNION(PatternRule, Any,
+#define FIELD_DEREF 255
+
+TAGGED_UNION_EX(PatternRule, (), Any,(
// _ pattern
(Any, struct {}),
// Enum variant
@@ -21,6 +23,12 @@ TAGGED_UNION(PatternRule, Any,
// General value
(Value, ::MIR::Constant),
(ValueRange, struct { ::MIR::Constant first, last; })
+ ),
+ ( , field_path(mv$(x.field_path)) ), (field_path = mv$(x.field_path);),
+ (
+ typedef ::std::vector<uint8_t> field_path_t;
+ field_path_t field_path;
+ )
);
::std::ostream& operator<<(::std::ostream& os, const PatternRule& x);
/// Constructed set of rules from a pattern
@@ -51,7 +59,10 @@ void MIR_LowerHIR_Match_DecisionTree( MirBuilder& builder, MirConverter& conv, :
struct PatternRulesetBuilder
{
::std::vector<PatternRule> m_rules;
+ PatternRule::field_path_t m_field_path;
+
void append_from(const Span& sp, const ::HIR::Pattern& pat, const ::HIR::TypeRef& ty);
+ void push_rule(PatternRule r);
};
// --------------------------------------------------------------------
@@ -87,7 +98,7 @@ void MIR_LowerHIR_Match( MirBuilder& builder, MirConverter& conv, ::HIR::ExprNod
pat_builder.append_from(node.span(), pat, node.m_value->m_res_type);
arm_rules.push_back( PatternRuleset { arm_idx, pat_idx, mv$(pat_builder.m_rules) } );
- DEBUG("(" << arm_idx << "," << pat_idx << ") [" << arm_rules.back().m_rules << "]");
+ DEBUG("(" << arm_idx << "," << pat_idx << ") " << pat << " ==> [" << arm_rules.back().m_rules << "]");
pat_idx += 1;
}
@@ -531,14 +542,21 @@ struct DecisionTreeNode
(String, ::std::vector< ::std::pair< ::std::string, Branch> >)
);
+ // TODO: Arm specialisation
bool is_specialisation;
+ PatternRule::field_path_t m_field_path;
Values m_branches;
Branch m_default;
- DecisionTreeNode():
- is_specialisation(false)
+ DecisionTreeNode( PatternRule::field_path_t field_path ):
+ is_specialisation(false)/*,
+ m_field_path( mv$(field_path) ) // */
{}
+ static Branch new_branch_subtree(PatternRule::field_path_t path) {
+ return Branch( box$(DecisionTreeNode( mv$(path) )) );
+ }
+
static Branch clone(const Branch& b);
static Values clone(const Values& x);
DecisionTreeNode clone() const;
@@ -556,6 +574,18 @@ struct DecisionTreeNode
/// HELPER: Unfies the rules from the provided branch with this node
void unify_from(const Branch& b);
+ ::MIR::LValue get_field(const ::MIR::LValue& base) const {
+ ::MIR::LValue cur = base.clone();
+ for(const auto idx : m_field_path) {
+ if( idx == FIELD_DEREF ) {
+ cur = ::MIR::LValue::make_Deref({ box$(cur) });
+ }
+ else {
+ cur = ::MIR::LValue::make_Field({ box$(cur), idx });
+ }
+ }
+ return cur;
+ }
friend ::std::ostream& operator<<(::std::ostream& os, const Branch& x);
friend ::std::ostream& operator<<(::std::ostream& os, const DecisionTreeNode& x);
@@ -575,13 +605,14 @@ struct DecisionTreeGen
return m_rule_blocks.at( rule_index );
}
- void populate_tree_vals(const Span& sp, const DecisionTreeNode& node, const ::HIR::TypeRef& ty, const ::MIR::LValue& val) {
- populate_tree_vals(sp, node, ty, 0, val, [](const auto& n){ DEBUG("final node = " << n); });
+ void generate_tree_code(const Span& sp, const DecisionTreeNode& node, const ::HIR::TypeRef& ty, const ::MIR::LValue& val) {
+ generate_tree_code(sp, node, ty, 0, val, 0, [](const auto& n){ DEBUG("final node = " << n); });
}
- void populate_tree_vals(
+ void generate_tree_code(
const Span& sp,
const DecisionTreeNode& node,
- const ::HIR::TypeRef& ty, unsigned int ty_ofs, const ::MIR::LValue& val,
+ const ::HIR::TypeRef& ty, unsigned int ty_ofs,
+ const ::MIR::LValue& val, unsigned int depth,
::std::function<void(const DecisionTreeNode&)> and_then
);
@@ -650,7 +681,7 @@ void MIR_LowerHIR_Match_DecisionTree( MirBuilder& builder, MirConverter& conv, :
// - Build tree by running each arm's pattern across it
DEBUG("- Building decision tree");
- DecisionTreeNode root_node;
+ DecisionTreeNode root_node({});
for( const auto& arm_rule : arm_rules )
{
auto arm_idx = arm_rule.arm_idx;
@@ -665,11 +696,16 @@ void MIR_LowerHIR_Match_DecisionTree( MirBuilder& builder, MirConverter& conv, :
DEBUG("- Emitting decision tree");
DecisionTreeGen gen { builder, rule_blocks };
builder.set_cur_block( first_cmp_block );
- gen.populate_tree_vals( node.span(), root_node, node.m_value->m_res_type, mv$(match_val) );
+ gen.generate_tree_code( node.span(), root_node, node.m_value->m_res_type, mv$(match_val) );
+ assert( !builder.block_active() );
}
::std::ostream& operator<<(::std::ostream& os, const PatternRule& x)
{
+ os <<"{";
+ for(const auto idx : x.field_path)
+ os << "." << static_cast<unsigned int>(idx);
+ os << "}=";
TU_MATCHA( (x), (e),
(Any,
os << "_";
@@ -747,8 +783,15 @@ bool PatternRuleset::is_before(const PatternRuleset& other) const
}
+void PatternRulesetBuilder::push_rule(PatternRule r)
+{
+ m_rules.push_back( mv$(r) );
+ m_rules.back().field_path = m_field_path;
+}
+
void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pat, const ::HIR::TypeRef& ty)
{
+ TRACE_FUNCTION_F("pat="<<pat<<", ty="<<ty<<", m_field_path.size()=" <<m_field_path.size() << " " << (m_field_path.empty() ? 0 : m_field_path.back()) );
struct H {
static uint64_t get_pattern_value_int(const Span& sp, const ::HIR::Pattern& pat, const ::HIR::Pattern::Value& val) {
TU_MATCH_DEF( ::HIR::Pattern::Value, (val), (e),
@@ -774,7 +817,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa
TU_MATCH_DEF(::HIR::Pattern::Data, (pat.m_data), (pe),
( BUG(sp, "Matching primitive with invalid pattern - " << pat); ),
(Any,
- m_rules.push_back( PatternRule::make_Any({}) );
+ this->push_rule( PatternRule::make_Any({}) );
),
(Range,
switch(e)
@@ -790,7 +833,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa
case ::HIR::CoreType::Usize: {
uint64_t start = H::get_pattern_value_int(sp, pat, pe.start);
uint64_t end = H::get_pattern_value_int(sp, pat, pe.end );
- m_rules.push_back( PatternRule::make_ValueRange( {::MIR::Constant(start), ::MIR::Constant(end)} ) );
+ this->push_rule( PatternRule::make_ValueRange( {::MIR::Constant(start), ::MIR::Constant(end)} ) );
} break;
case ::HIR::CoreType::I8:
case ::HIR::CoreType::I16:
@@ -799,7 +842,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa
case ::HIR::CoreType::Isize: {
int64_t start = H::get_pattern_value_int(sp, pat, pe.start);
int64_t end = H::get_pattern_value_int(sp, pat, pe.end );
- m_rules.push_back( PatternRule::make_ValueRange( {::MIR::Constant(start), ::MIR::Constant(end)} ) );
+ this->push_rule( PatternRule::make_ValueRange( {::MIR::Constant(start), ::MIR::Constant(end)} ) );
} break;
case ::HIR::CoreType::Bool:
BUG(sp, "Can't range match on Bool");
@@ -807,7 +850,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa
case ::HIR::CoreType::Char: {
uint64_t start = H::get_pattern_value_int(sp, pat, pe.start);
uint64_t end = H::get_pattern_value_int(sp, pat, pe.end );
- m_rules.push_back( PatternRule::make_ValueRange( {::MIR::Constant(start), ::MIR::Constant(end)} ) );
+ this->push_rule( PatternRule::make_ValueRange( {::MIR::Constant(start), ::MIR::Constant(end)} ) );
} break;
case ::HIR::CoreType::Str:
BUG(sp, "Hit match over `str` - must be `&str`");
@@ -821,7 +864,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa
case ::HIR::CoreType::F64: {
TODO(sp, "Match over float, is it valid?");
//double val = pe.val.as_Float().value;
- //m_rules.push_back( PatternRule::make_Value( ::MIR::Constant(val) ) );
+ //this->push_rule( PatternRule::make_Value( ::MIR::Constant(val) ) );
} break;
case ::HIR::CoreType::U8:
case ::HIR::CoreType::U16:
@@ -829,7 +872,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa
case ::HIR::CoreType::U64:
case ::HIR::CoreType::Usize: {
uint64_t val = H::get_pattern_value_int(sp, pat, pe.val);
- m_rules.push_back( PatternRule::make_Value( ::MIR::Constant(val) ) );
+ this->push_rule( PatternRule::make_Value( ::MIR::Constant(val) ) );
} break;
case ::HIR::CoreType::I8:
case ::HIR::CoreType::I16:
@@ -837,16 +880,16 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa
case ::HIR::CoreType::I64:
case ::HIR::CoreType::Isize: {
int64_t val = H::get_pattern_value_int(sp, pat, pe.val);
- m_rules.push_back( PatternRule::make_Value( ::MIR::Constant(val) ) );
+ this->push_rule( PatternRule::make_Value( ::MIR::Constant(val) ) );
} break;
case ::HIR::CoreType::Bool:
// TODO: Support values from `const` too
- m_rules.push_back( PatternRule::make_Bool( pe.val.as_Integer().value != 0 ) );
+ this->push_rule( PatternRule::make_Bool( pe.val.as_Integer().value != 0 ) );
break;
case ::HIR::CoreType::Char: {
// Char is just another name for 'u32'... but with a restricted range
uint64_t val = H::get_pattern_value_int(sp, pat, pe.val);
- m_rules.push_back( PatternRule::make_Value( ::MIR::Constant(val) ) );
+ this->push_rule( PatternRule::make_Value( ::MIR::Constant(val) ) );
} break;
case ::HIR::CoreType::Str:
BUG(sp, "Hit match over `str` - must be `&str`");
@@ -856,18 +899,24 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa
)
),
(Tuple,
+ m_field_path.push_back(0);
TU_MATCH_DEF(::HIR::Pattern::Data, (pat.m_data), (pe),
( BUG(sp, "Matching tuple with invalid pattern - " << pat); ),
(Any,
- for(const auto& sty : e)
+ for(const auto& sty : e) {
this->append_from(sp, pat, sty);
+ m_field_path.back() ++;
+ }
),
(Tuple,
assert(e.size() == pe.sub_patterns.size());
- for(unsigned int i = 0; i < e.size(); i ++)
+ for(unsigned int i = 0; i < e.size(); i ++) {
this->append_from(sp, pe.sub_patterns[i], e[i]);
+ m_field_path.back() ++;
+ }
)
)
+ m_field_path.pop_back();
),
(Path,
// This is either a struct destructure or an enum
@@ -879,15 +928,11 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa
TU_MATCH_DEF( ::HIR::Pattern::Data, (pat.m_data), (pe),
( BUG(sp, "Matching opaque type with invalid pattern - " << pat); ),
(Any,
- m_rules.push_back( PatternRule::make_Any({}) );
+ this->push_rule( PatternRule::make_Any({}) );
)
)
),
(Struct,
- //auto monomorph_cb = [&](const auto& ty)->const auto& {
- // const auto& ge = ty.m_data.as_Generic();
- // if( ge.
- // };
auto monomorph = [&](const auto& ty) { return monomorphise_type(sp, pbe->m_params, e.path.m_data.as_Generic().m_params, ty); };
const auto& str_data = pbe->m_data;
TU_MATCHA( (str_data), (sd),
@@ -896,6 +941,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa
( BUG(sp, "Match not allowed, " << ty << " with " << pat); ),
(Any,
// Nothing.
+ this->push_rule( PatternRule::make_Any({}) );
),
(Value,
TODO(sp, "Match over struct - Unit + Value");
@@ -917,18 +963,18 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa
)
),
(Named,
- // TODO: Avoid needing to clone everything.
- ::std::vector< ::HIR::TypeRef> types;
- types.reserve( sd.size() );
- for( const auto& fld : sd ) {
- types.push_back( monomorph(fld.second.ent) );
- }
-
TU_MATCH_DEF( ::HIR::Pattern::Data, (pat.m_data), (pe),
( BUG(sp, "Match not allowed, " << ty << " with " << pat); ),
(Any,
- for(const auto& sty : types)
- this->append_from(sp, pat, sty);
+ m_field_path.push_back(0);
+ for(const auto& fld : sd)
+ {
+ ::HIR::TypeRef tmp;
+ const auto& sty_mono = (monomorphise_type_needed(fld.second.ent) ? tmp = monomorph(fld.second.ent) : fld.second.ent);
+ this->append_from(sp, pat, sty_mono);
+ m_field_path.back() ++;
+ }
+ m_field_path.pop_back();
),
(Struct,
TODO(sp, "Match over struct - Named + Struct");
@@ -942,26 +988,32 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa
TU_MATCH_DEF( ::HIR::Pattern::Data, (pat.m_data), (pe),
( BUG(sp, "Match not allowed, " << ty << " with " << pat); ),
(Any,
- m_rules.push_back( PatternRule::make_Any({}) );
+ this->push_rule( PatternRule::make_Any({}) );
),
(EnumValue,
- m_rules.push_back( PatternRule::make_Variant( {pe.binding_idx, {} } ) );
+ this->push_rule( PatternRule::make_Variant( {pe.binding_idx, {} } ) );
),
(EnumTuple,
const auto& var_def = pe.binding_ptr->m_variants.at(pe.binding_idx);
const auto& fields_def = var_def.second.as_Tuple();
PatternRulesetBuilder sub_builder;
+ sub_builder.m_field_path.push_back(0);
for( unsigned int i = 0; i < pe.sub_patterns.size(); i ++ )
{
+ sub_builder.m_field_path.back() = i;
const auto& subpat = pe.sub_patterns[i];
- auto subty = monomorph(fields_def[i].ent);
+ const auto& ty_tpl = fields_def[i].ent;
+
+ ::HIR::TypeRef tmp;
+ const auto& subty = (monomorphise_type_needed(ty_tpl) ? tmp = monomorph(ty_tpl) : ty_tpl);
+
sub_builder.append_from( sp, subpat, subty );
}
- m_rules.push_back( PatternRule::make_Variant({ pe.binding_idx, mv$(sub_builder.m_rules) }) );
+ this->push_rule( PatternRule::make_Variant({ pe.binding_idx, mv$(sub_builder.m_rules) }) );
),
(EnumTupleWildcard,
- m_rules.push_back( PatternRule::make_Variant({ pe.binding_idx, {} }) );
+ this->push_rule( PatternRule::make_Variant({ pe.binding_idx, {} }) );
),
(EnumStruct,
const auto& var_def = pe.binding_ptr->m_variants.at(pe.binding_idx);
@@ -979,8 +1031,11 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa
}
// 2. Iterate this list and recurse on the patterns
PatternRulesetBuilder sub_builder;
+ sub_builder.m_field_path.push_back(0);
for( unsigned int i = 0; i < tmp.size(); i ++ )
{
+ sub_builder.m_field_path.back() = i;
+
auto subty = monomorph(fields_def[i].second.ent);
if( tmp[i] == ~0u ) {
sub_builder.append_from( sp, ::HIR::Pattern(), subty );
@@ -990,7 +1045,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa
sub_builder.append_from( sp, subpat, subty );
}
}
- m_rules.push_back( PatternRule::make_Variant({ pe.binding_idx, mv$(sub_builder.m_rules) }) );
+ this->push_rule( PatternRule::make_Variant({ pe.binding_idx, mv$(sub_builder.m_rules) }) );
)
)
)
@@ -1001,7 +1056,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa
TU_MATCH_DEF( ::HIR::Pattern::Data, (pat.m_data), (pe),
( BUG(sp, "Match not allowed, " << ty << " with " << pat); ),
(Any,
- m_rules.push_back( PatternRule::make_Any({}) );
+ this->push_rule( PatternRule::make_Any({}) );
)
)
),
@@ -1024,6 +1079,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa
}
),
(Borrow,
+ m_field_path.push_back( FIELD_DEREF );
TU_MATCH_DEF( ::HIR::Pattern::Data, (pat.m_data), (pe),
( BUG(sp, "Matching borrow invalid pattern - " << pat); ),
(Any,
@@ -1035,13 +1091,14 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa
(Value,
if( pe.val.is_String() ) {
const auto& s = pe.val.as_String();
- m_rules.push_back( PatternRule::make_Value(s) );
+ this->push_rule( PatternRule::make_Value(s) );
}
else {
BUG(sp, "Matching borrow invalid pattern - " << pat);
}
)
)
+ m_field_path.pop_back();
),
(Pointer,
if( pat.m_data.is_Any() ) {
@@ -1117,7 +1174,7 @@ DecisionTreeNode::Values DecisionTreeNode::clone(const DecisionTreeNode::Values&
throw "";
}
DecisionTreeNode DecisionTreeNode::clone() const {
- DecisionTreeNode rv;
+ DecisionTreeNode rv(m_field_path);
rv.is_specialisation = is_specialisation;
rv.m_branches = clone(m_branches);
rv.m_default = clone(m_default);
@@ -1128,6 +1185,13 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule
assert( rule_count > 0 );
const auto& rule = *first_rule;
+ if( m_field_path.size() == 0 ) {
+ m_field_path = rule.field_path;
+ }
+ else {
+ assert( m_field_path == rule.field_path );
+ }
+
TU_MATCHA( (rule), (e),
(Any, {
if( m_default.is_Unset() ) {
@@ -1135,9 +1199,8 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule
and_then(m_default);
}
else {
- auto be = box$(DecisionTreeNode());
- be->populate_tree_from_rule(sp, first_rule+1, rule_count-1, and_then);
- m_default = Branch(mv$(be));
+ m_default = new_branch_subtree(rule.field_path);
+ m_default.as_Subtree()->populate_tree_from_rule(sp, first_rule+1, rule_count-1, and_then);
}
}
else TU_IFLET( Branch, m_default, Subtree, be,
@@ -1162,7 +1225,7 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule
auto it = ::std::find_if( be.begin(), be.end(), [&](const auto& x){ return x.first >= e.idx; });
// If this variant isn't yet processed, add a new subtree for it
if( it == be.end() || it->first != e.idx ) {
- it = be.insert(it, ::std::make_pair(e.idx, Branch( box$(DecisionTreeNode()) )));
+ it = be.insert(it, ::std::make_pair(e.idx, new_branch_subtree(rule.field_path)));
assert( it->second.is_Subtree() );
}
else {
@@ -1178,7 +1241,7 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule
{
subtree.populate_tree_from_rule(sp, e.sub_rules.data(), e.sub_rules.size(), [&](auto& branch){
ASSERT_BUG(sp, branch.is_Unset(), "Duplicate terminator");
- branch = Branch( box$(DecisionTreeNode()) );
+ branch = new_branch_subtree(rule.field_path);
branch.as_Subtree()->populate_tree_from_rule(sp, first_rule+1, rule_count-1, and_then);
});
}
@@ -1206,7 +1269,7 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule
auto& branch = (e ? be.true_branch : be.false_branch);
if( branch.is_Unset() ) {
- branch = Branch( box$( DecisionTreeNode() ) );
+ branch = new_branch_subtree( rule.field_path );
}
else if( branch.is_Terminal() ) {
BUG(sp, "Duplicate terminal rule - " << branch.as_Terminal());
@@ -1239,7 +1302,7 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule
auto& be = m_branches.as_Unsigned();
auto it = ::std::find_if(be.begin(), be.end(), [&](const auto& v){ return v.first.end >= ve; });
if( it == be.end() || it->first.start > ve ) {
- it = be.insert( it, ::std::make_pair( Range<uint64_t> { ve,ve }, Branch( box$(DecisionTreeNode()) ) ) );
+ it = be.insert( it, ::std::make_pair( Range<uint64_t> { ve,ve }, new_branch_subtree(rule.field_path) ) );
}
else if( it->first.start == ve && it->first.end == ve ) {
// Equal, continue and add sub-pat
@@ -1280,7 +1343,7 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule
auto it = ::std::find_if(be.begin(), be.end(), [&](const auto& v){ return v.first >= ve; });
if( it == be.end() || it->first != ve ) {
- it = be.insert( it, ::std::make_pair(ve, Branch( box$(DecisionTreeNode()) ) ) );
+ it = be.insert( it, ::std::make_pair(ve, new_branch_subtree(rule.field_path) ) );
}
auto& branch = it->second;
if( rule_count > 1 )
@@ -1317,7 +1380,7 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule
auto it = ::std::find_if(be.begin(), be.end(), [&](const auto& v){ return v.first >= ve_start || v.first.contains(ve_end); });
// If the end of the list was reached, OR the located entry sorts after the end of this range
if( it == be.end() || it->first >= ve_end ) {
- it = be.insert( it, ::std::make_pair( Range<uint64_t> { ve_start,ve_end }, Branch( box$(DecisionTreeNode()) ) ) );
+ it = be.insert( it, ::std::make_pair( Range<uint64_t> { ve_start,ve_end }, new_branch_subtree(rule.field_path) ) );
}
else if( it->first.start == ve_start && it->first.end == ve_end ) {
// Equal, add sub-pattern
@@ -1334,21 +1397,21 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule
if( ve_start == it->first.start ) {
// Add single range after
it ++;
- it = be.insert(it, ::std::make_pair( Range<uint64_t> { ve_start + 1, ve_end }, Branch( box$(DecisionTreeNode()) ) ) );
+ it = be.insert(it, ::std::make_pair( Range<uint64_t> { ve_start + 1, ve_end }, new_branch_subtree(rule.field_path) ) );
}
else if( ve_end == it->first.start ) {
// Add single range before
- it = be.insert(it, ::std::make_pair( Range<uint64_t> { ve_start, ve_end - 1 }, Branch( box$(DecisionTreeNode()) ) ) );
+ it = be.insert(it, ::std::make_pair( Range<uint64_t> { ve_start, ve_end - 1 }, new_branch_subtree(rule.field_path) ) );
}
else {
// Add two ranges
auto end1 = it->first.start - 1;
auto start2 = it->first.end + 1;
- it = be.insert(it, ::std::make_pair( Range<uint64_t> { ve_start, end1 }, Branch( box$(DecisionTreeNode()) ) ) );
+ it = be.insert(it, ::std::make_pair( Range<uint64_t> { ve_start, end1 }, new_branch_subtree(rule.field_path) ) );
auto& branch_1 = it->second;
it ++;
it ++; // Skip the original entry
- it = be.insert(it, ::std::make_pair( Range<uint64_t> { start2, ve_end }, Branch( box$(DecisionTreeNode()) ) ) );
+ it = be.insert(it, ::std::make_pair( Range<uint64_t> { start2, ve_end }, new_branch_subtree(rule.field_path) ) );
auto& branch_2 = it->second;
if( rule_count > 1 )
@@ -1625,7 +1688,10 @@ void DecisionTreeNode::unify_from(const Branch& b)
return os;
}
::std::ostream& operator<<(::std::ostream& os, const DecisionTreeNode& x) {
- os << "DTN { ";
+ os << "DTN [";
+ for(const auto idx : x.m_field_path)
+ os << "." << static_cast<unsigned int>(idx);
+ os << "] { ";
TU_MATCHA( (x.m_branches), (e),
(Unset,
os << "!, ";
@@ -1634,11 +1700,13 @@ void DecisionTreeNode::unify_from(const Branch& b)
os << "false = " << e.false_branch << ", true = " << e.true_branch << ", ";
),
(Variant,
+ os << "V ";
for(const auto& branch : e) {
os << branch.first << " = " << branch.second << ", ";
}
),
(Unsigned,
+ os << "U ";
for(const auto& branch : e) {
const auto& range = branch.first;
if( range.start == range.end ) {
@@ -1651,6 +1719,7 @@ void DecisionTreeNode::unify_from(const Branch& b)
}
),
(Signed,
+ os << "S ";
for(const auto& branch : e) {
const auto& range = branch.first;
if( range.start == range.end ) {
@@ -1678,16 +1747,14 @@ void DecisionTreeNode::unify_from(const Branch& b)
// ----------------------------
// DecisionTreeGen
// ----------------------------
-void DecisionTreeGen::populate_tree_vals(
+void DecisionTreeGen::generate_tree_code(
const Span& sp,
const DecisionTreeNode& node,
- const ::HIR::TypeRef& ty, unsigned int ty_ofs, const ::MIR::LValue& val,
+ const ::HIR::TypeRef& ty, unsigned int ty_ofs,
+ const ::MIR::LValue& base_val, unsigned int depth,
::std::function<void(const DecisionTreeNode&)> and_then
)
{
- struct H {
- };
-
TRACE_FUNCTION_F("ty=" << ty << ", ty_ofs=" << ty_ofs << ", node=" << node);
TU_MATCHA( (ty.m_data), (e),
@@ -1717,7 +1784,7 @@ void DecisionTreeGen::populate_tree_vals(
// Emit an if based on the route taken
auto bb_false = m_builder.new_bb_unlinked();
auto bb_true = m_builder.new_bb_unlinked();
- m_builder.end_block( ::MIR::Terminator::make_If({ val.clone(), bb_true, bb_false }) );
+ m_builder.end_block( ::MIR::Terminator::make_If({ node.get_field(base_val), bb_true, bb_false }) );
// Recurse into sub-patterns
const auto& branch_false = ( !branches.false_branch.is_Unset() ? branches.false_branch : node.m_default );
@@ -1735,11 +1802,11 @@ void DecisionTreeGen::populate_tree_vals(
case ::HIR::CoreType::U64:
case ::HIR::CoreType::Usize:
ASSERT_BUG(sp, node.m_branches.is_Unsigned(), "Tree for unsigned isn't a _Unsigned - node="<<node);
- this->generate_branches_Unsigned(sp, node.m_default, node.m_branches.as_Unsigned(), ty, val, mv$(and_then));
+ this->generate_branches_Unsigned(sp, node.m_default, node.m_branches.as_Unsigned(), ty, node.get_field(base_val), mv$(and_then));
break;
case ::HIR::CoreType::Char:
ASSERT_BUG(sp, node.m_branches.is_Unsigned(), "Tree for char isn't a _Unsigned - node="<<node);
- this->generate_branches_Char(sp, node.m_default, node.m_branches.as_Unsigned(), ty, val, mv$(and_then));
+ this->generate_branches_Char(sp, node.m_default, node.m_branches.as_Unsigned(), ty, node.get_field(base_val), mv$(and_then));
break;
default:
TODO(sp, "Primitive - " << ty);
@@ -1748,15 +1815,21 @@ void DecisionTreeGen::populate_tree_vals(
),
(Tuple,
// Tuple - Recurse on each sub-type (increasing the index)
+ // - If complete, call completion callback
if( ty_ofs == e.size() ) {
and_then(node);
}
- else {
- populate_tree_vals( sp, node,
- e[ty_ofs], 0, ::MIR::LValue::make_Field({ box$(val.clone()), ty_ofs}),
- [&](auto& n){ this->populate_tree_vals(sp, n, ty, ty_ofs+1, val, and_then); }
+ // - If the node is for this type, recurse
+ else if( node.m_field_path[depth] == ty_ofs ) {
+ generate_tree_code( sp, node,
+ e[ty_ofs], 0, base_val, depth+1,
+ [&](auto& n){ this->generate_tree_code(sp, n, ty, ty_ofs+1, base_val, depth, and_then); }
);
}
+ // - Otherwise, go to the next node
+ else {
+ this->generate_tree_code(sp, node, ty, ty_ofs+1, base_val, depth, and_then);
+ }
),
(Path,
// This is either a struct destructure or an enum
@@ -1768,12 +1841,51 @@ void DecisionTreeGen::populate_tree_vals(
and_then(node);
),
(Struct,
- TODO(sp, "Match over struct - " << e.path);
+ auto monomorph = [&](const auto& ty) { return monomorphise_type(sp, pbe->m_params, e.path.m_data.as_Generic().m_params, ty); };
+ TU_MATCHA( (pbe->m_data), (fields),
+ (Unit,
+ and_then(node);
+ ),
+ (Tuple,
+ if( ty_ofs == fields.size() ) {
+ and_then(node);
+ }
+ else if( node.m_field_path[depth] == ty_ofs ) {
+ const auto& fld = fields[ty_ofs];
+ ::HIR::TypeRef tmp;
+ const auto& fld_ty = (monomorphise_type_needed(fld.ent) ? tmp = monomorph(fld.ent) : fld.ent);
+ generate_tree_code( sp, node,
+ fld_ty, 0, base_val, depth+1,
+ [&](auto& n){ this->generate_tree_code(sp, n, ty, ty_ofs+1, base_val, depth, and_then); }
+ );
+ }
+ else {
+ this->generate_tree_code(sp, node, ty, ty_ofs+1, base_val, depth, and_then);
+ }
+ ),
+ (Named,
+ if( ty_ofs == fields.size() ) {
+ and_then(node);
+ }
+ else if( node.m_field_path[depth] == ty_ofs ) {
+ const auto& fld = fields[ty_ofs].second;
+ ::HIR::TypeRef tmp;
+ const auto& fld_ty = (monomorphise_type_needed(fld.ent) ? tmp = monomorph(fld.ent) : fld.ent);
+ generate_tree_code( sp, node,
+ fld_ty, 0, base_val, depth+1,
+ [&](auto& n){ this->generate_tree_code(sp, n, ty, ty_ofs+1, base_val, depth, and_then); }
+ );
+ }
+ else {
+ this->generate_tree_code(sp, node, ty, ty_ofs+1, base_val, depth, and_then);
+ }
+ )
+ )
),
(Enum,
ASSERT_BUG(sp, node.m_branches.is_Variant(), "Tree for enum isn't a Variant - node="<<node);
assert(pbe);
- this->generate_branches_Enum(sp, node.m_default, node.m_branches.as_Variant(), ty, val, mv$(and_then));
+ this->generate_branches_Enum(sp, node.m_default, node.m_branches.as_Variant(), ty, node.get_field(base_val), mv$(and_then));
)
)
),
@@ -1808,17 +1920,19 @@ void DecisionTreeGen::populate_tree_vals(
assert( !branches.empty() );
for(const auto& branch : branches)
{
+ auto have_val = node.get_field(base_val);
+
auto next_bb = (&branch == &branches.back() ? default_bb : m_builder.new_bb_unlinked());
auto test_val = m_builder.lvalue_or_temp( ::HIR::TypeRef(::HIR::CoreType::Str), ::MIR::Constant(branch.first) );
auto cmp_gt_bb = m_builder.new_bb_unlinked();
- auto lt_val = m_builder.lvalue_or_temp( ::HIR::CoreType::Bool, ::MIR::RValue::make_BinOp({ val.clone(), ::MIR::eBinOp::LT, test_val.clone() }) );
+ auto lt_val = m_builder.lvalue_or_temp( ::HIR::CoreType::Bool, ::MIR::RValue::make_BinOp({ have_val.clone(), ::MIR::eBinOp::LT, test_val.clone() }) );
m_builder.end_block( ::MIR::Terminator::make_If({ mv$(lt_val), default_bb, cmp_gt_bb }) );
m_builder.set_cur_block(cmp_gt_bb);
auto eq_bb = m_builder.new_bb_unlinked();
- auto gt_val = m_builder.lvalue_or_temp( ::HIR::CoreType::Bool, ::MIR::RValue::make_BinOp({ val.clone(), ::MIR::eBinOp::GT, test_val.clone() }) );
+ auto gt_val = m_builder.lvalue_or_temp( ::HIR::CoreType::Bool, ::MIR::RValue::make_BinOp({ mv$(have_val), ::MIR::eBinOp::GT, test_val.clone() }) );
m_builder.end_block( ::MIR::Terminator::make_If({ mv$(gt_val), next_bb, eq_bb }) );
m_builder.set_cur_block(eq_bb);
@@ -1832,10 +1946,9 @@ void DecisionTreeGen::populate_tree_vals(
TODO(sp, "Match over &[T]");
}
else {
- populate_tree_vals( sp, node,
- *e.inner, 0, ::MIR::LValue::make_Deref({ box$(val.clone()) }),
- and_then
- );
+ // TODO: Should this check if the index is -1? The assertion doesn't fire in general use, so looks good.
+ ASSERT_BUG( sp, node.m_field_path[depth] == FIELD_DEREF, "& not matching on deref " << depth << " node=" <<node );
+ generate_tree_code( sp, node, *e.inner, 0, base_val, depth+1, and_then );
}
),
(Pointer,
@@ -2011,7 +2124,7 @@ void DecisionTreeGen::generate_branches_Enum(
{
auto bb = variant_blocks[branch.first];
const auto& var = variants[branch.first];
- DEBUG(branch.first << " " << var.first << " = " << branch);
+ DEBUG(branch.first << " " << var.first << " = " << branch.second);
m_builder.set_cur_block( bb );
this->generate_branch(branch.second, [&](auto& subnode) {
TU_MATCHA( (var.second), (e),
@@ -2029,7 +2142,7 @@ void DecisionTreeGen::generate_branches_Enum(
ents.push_back( monomorphise_type(sp, enum_ref.m_params, enum_path.m_params, fld.ent) );
}
::HIR::TypeRef fake_ty { mv$(ents) };
- this->populate_tree_vals(sp, subnode, fake_ty, 0, ::MIR::LValue::make_Downcast({ box$(val.clone()), branch.first }), and_then);
+ this->generate_tree_code(sp, subnode, fake_ty, 0, ::MIR::LValue::make_Downcast({ box$(val.clone()), branch.first }), 0, and_then);
),
(Struct,
TODO(sp, "Enum pattern - struct");
@@ -2038,16 +2151,10 @@ void DecisionTreeGen::generate_branches_Enum(
});
}
- DEBUG("_");
- TU_MATCHA( (default_branch), (be),
- (Unset, ),
- (Terminal,
- m_builder.set_cur_block(any_block);
- m_builder.end_block( ::MIR::Terminator::make_Goto( this->get_block_for_rule( be ) ) );
- ),
- (Subtree,
+ DEBUG("_ = " << default_branch);
+ if( !default_branch.is_Unset() )
+ {
m_builder.set_cur_block(any_block);
- and_then( *be );
- )
- )
+ this->generate_branch(default_branch, and_then);
+ }
}