diff options
author | John Hodge <tpg@mutabah.net> | 2016-08-13 22:15:27 +0800 |
---|---|---|
committer | John Hodge <tpg@mutabah.net> | 2016-08-13 22:15:27 +0800 |
commit | 30567917626efb381e94ef719447db88d0d5685f (patch) | |
tree | 1440ed757c539a2456fed27e2be204343b3efcbf /src/mir/from_hir_match.cpp | |
parent | 4534f1e5acb1deaf1efccc0f91a8e0c09c163259 (diff) | |
download | mrust-30567917626efb381e94ef719447db88d0d5685f.tar.gz |
MIR Gen Match - Rewrite handling of DecisionTree for correctness
- Pattern entries record what field they were from
- Generation skips if the current test is not for that field
Diffstat (limited to 'src/mir/from_hir_match.cpp')
-rw-r--r-- | src/mir/from_hir_match.cpp | 289 |
1 files changed, 198 insertions, 91 deletions
diff --git a/src/mir/from_hir_match.cpp b/src/mir/from_hir_match.cpp index 8f1fd76a..5fda73a2 100644 --- a/src/mir/from_hir_match.cpp +++ b/src/mir/from_hir_match.cpp @@ -11,7 +11,9 @@ void MIR_LowerHIR_Match( MirBuilder& builder, MirConverter& conv, ::HIR::ExprNode_Match& node, ::MIR::LValue match_val ); -TAGGED_UNION(PatternRule, Any, +#define FIELD_DEREF 255 + +TAGGED_UNION_EX(PatternRule, (), Any,( // _ pattern (Any, struct {}), // Enum variant @@ -21,6 +23,12 @@ TAGGED_UNION(PatternRule, Any, // General value (Value, ::MIR::Constant), (ValueRange, struct { ::MIR::Constant first, last; }) + ), + ( , field_path(mv$(x.field_path)) ), (field_path = mv$(x.field_path);), + ( + typedef ::std::vector<uint8_t> field_path_t; + field_path_t field_path; + ) ); ::std::ostream& operator<<(::std::ostream& os, const PatternRule& x); /// Constructed set of rules from a pattern @@ -51,7 +59,10 @@ void MIR_LowerHIR_Match_DecisionTree( MirBuilder& builder, MirConverter& conv, : struct PatternRulesetBuilder { ::std::vector<PatternRule> m_rules; + PatternRule::field_path_t m_field_path; + void append_from(const Span& sp, const ::HIR::Pattern& pat, const ::HIR::TypeRef& ty); + void push_rule(PatternRule r); }; // -------------------------------------------------------------------- @@ -87,7 +98,7 @@ void MIR_LowerHIR_Match( MirBuilder& builder, MirConverter& conv, ::HIR::ExprNod pat_builder.append_from(node.span(), pat, node.m_value->m_res_type); arm_rules.push_back( PatternRuleset { arm_idx, pat_idx, mv$(pat_builder.m_rules) } ); - DEBUG("(" << arm_idx << "," << pat_idx << ") [" << arm_rules.back().m_rules << "]"); + DEBUG("(" << arm_idx << "," << pat_idx << ") " << pat << " ==> [" << arm_rules.back().m_rules << "]"); pat_idx += 1; } @@ -531,14 +542,21 @@ struct DecisionTreeNode (String, ::std::vector< ::std::pair< ::std::string, Branch> >) ); + // TODO: Arm specialisation bool is_specialisation; + PatternRule::field_path_t m_field_path; Values m_branches; Branch m_default; - DecisionTreeNode(): - is_specialisation(false) + DecisionTreeNode( PatternRule::field_path_t field_path ): + is_specialisation(false)/*, + m_field_path( mv$(field_path) ) // */ {} + static Branch new_branch_subtree(PatternRule::field_path_t path) { + return Branch( box$(DecisionTreeNode( mv$(path) )) ); + } + static Branch clone(const Branch& b); static Values clone(const Values& x); DecisionTreeNode clone() const; @@ -556,6 +574,18 @@ struct DecisionTreeNode /// HELPER: Unfies the rules from the provided branch with this node void unify_from(const Branch& b); + ::MIR::LValue get_field(const ::MIR::LValue& base) const { + ::MIR::LValue cur = base.clone(); + for(const auto idx : m_field_path) { + if( idx == FIELD_DEREF ) { + cur = ::MIR::LValue::make_Deref({ box$(cur) }); + } + else { + cur = ::MIR::LValue::make_Field({ box$(cur), idx }); + } + } + return cur; + } friend ::std::ostream& operator<<(::std::ostream& os, const Branch& x); friend ::std::ostream& operator<<(::std::ostream& os, const DecisionTreeNode& x); @@ -575,13 +605,14 @@ struct DecisionTreeGen return m_rule_blocks.at( rule_index ); } - void populate_tree_vals(const Span& sp, const DecisionTreeNode& node, const ::HIR::TypeRef& ty, const ::MIR::LValue& val) { - populate_tree_vals(sp, node, ty, 0, val, [](const auto& n){ DEBUG("final node = " << n); }); + void generate_tree_code(const Span& sp, const DecisionTreeNode& node, const ::HIR::TypeRef& ty, const ::MIR::LValue& val) { + generate_tree_code(sp, node, ty, 0, val, 0, [](const auto& n){ DEBUG("final node = " << n); }); } - void populate_tree_vals( + void generate_tree_code( const Span& sp, const DecisionTreeNode& node, - const ::HIR::TypeRef& ty, unsigned int ty_ofs, const ::MIR::LValue& val, + const ::HIR::TypeRef& ty, unsigned int ty_ofs, + const ::MIR::LValue& val, unsigned int depth, ::std::function<void(const DecisionTreeNode&)> and_then ); @@ -650,7 +681,7 @@ void MIR_LowerHIR_Match_DecisionTree( MirBuilder& builder, MirConverter& conv, : // - Build tree by running each arm's pattern across it DEBUG("- Building decision tree"); - DecisionTreeNode root_node; + DecisionTreeNode root_node({}); for( const auto& arm_rule : arm_rules ) { auto arm_idx = arm_rule.arm_idx; @@ -665,11 +696,16 @@ void MIR_LowerHIR_Match_DecisionTree( MirBuilder& builder, MirConverter& conv, : DEBUG("- Emitting decision tree"); DecisionTreeGen gen { builder, rule_blocks }; builder.set_cur_block( first_cmp_block ); - gen.populate_tree_vals( node.span(), root_node, node.m_value->m_res_type, mv$(match_val) ); + gen.generate_tree_code( node.span(), root_node, node.m_value->m_res_type, mv$(match_val) ); + assert( !builder.block_active() ); } ::std::ostream& operator<<(::std::ostream& os, const PatternRule& x) { + os <<"{"; + for(const auto idx : x.field_path) + os << "." << static_cast<unsigned int>(idx); + os << "}="; TU_MATCHA( (x), (e), (Any, os << "_"; @@ -747,8 +783,15 @@ bool PatternRuleset::is_before(const PatternRuleset& other) const } +void PatternRulesetBuilder::push_rule(PatternRule r) +{ + m_rules.push_back( mv$(r) ); + m_rules.back().field_path = m_field_path; +} + void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pat, const ::HIR::TypeRef& ty) { + TRACE_FUNCTION_F("pat="<<pat<<", ty="<<ty<<", m_field_path.size()=" <<m_field_path.size() << " " << (m_field_path.empty() ? 0 : m_field_path.back()) ); struct H { static uint64_t get_pattern_value_int(const Span& sp, const ::HIR::Pattern& pat, const ::HIR::Pattern::Value& val) { TU_MATCH_DEF( ::HIR::Pattern::Value, (val), (e), @@ -774,7 +817,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa TU_MATCH_DEF(::HIR::Pattern::Data, (pat.m_data), (pe), ( BUG(sp, "Matching primitive with invalid pattern - " << pat); ), (Any, - m_rules.push_back( PatternRule::make_Any({}) ); + this->push_rule( PatternRule::make_Any({}) ); ), (Range, switch(e) @@ -790,7 +833,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa case ::HIR::CoreType::Usize: { uint64_t start = H::get_pattern_value_int(sp, pat, pe.start); uint64_t end = H::get_pattern_value_int(sp, pat, pe.end ); - m_rules.push_back( PatternRule::make_ValueRange( {::MIR::Constant(start), ::MIR::Constant(end)} ) ); + this->push_rule( PatternRule::make_ValueRange( {::MIR::Constant(start), ::MIR::Constant(end)} ) ); } break; case ::HIR::CoreType::I8: case ::HIR::CoreType::I16: @@ -799,7 +842,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa case ::HIR::CoreType::Isize: { int64_t start = H::get_pattern_value_int(sp, pat, pe.start); int64_t end = H::get_pattern_value_int(sp, pat, pe.end ); - m_rules.push_back( PatternRule::make_ValueRange( {::MIR::Constant(start), ::MIR::Constant(end)} ) ); + this->push_rule( PatternRule::make_ValueRange( {::MIR::Constant(start), ::MIR::Constant(end)} ) ); } break; case ::HIR::CoreType::Bool: BUG(sp, "Can't range match on Bool"); @@ -807,7 +850,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa case ::HIR::CoreType::Char: { uint64_t start = H::get_pattern_value_int(sp, pat, pe.start); uint64_t end = H::get_pattern_value_int(sp, pat, pe.end ); - m_rules.push_back( PatternRule::make_ValueRange( {::MIR::Constant(start), ::MIR::Constant(end)} ) ); + this->push_rule( PatternRule::make_ValueRange( {::MIR::Constant(start), ::MIR::Constant(end)} ) ); } break; case ::HIR::CoreType::Str: BUG(sp, "Hit match over `str` - must be `&str`"); @@ -821,7 +864,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa case ::HIR::CoreType::F64: { TODO(sp, "Match over float, is it valid?"); //double val = pe.val.as_Float().value; - //m_rules.push_back( PatternRule::make_Value( ::MIR::Constant(val) ) ); + //this->push_rule( PatternRule::make_Value( ::MIR::Constant(val) ) ); } break; case ::HIR::CoreType::U8: case ::HIR::CoreType::U16: @@ -829,7 +872,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa case ::HIR::CoreType::U64: case ::HIR::CoreType::Usize: { uint64_t val = H::get_pattern_value_int(sp, pat, pe.val); - m_rules.push_back( PatternRule::make_Value( ::MIR::Constant(val) ) ); + this->push_rule( PatternRule::make_Value( ::MIR::Constant(val) ) ); } break; case ::HIR::CoreType::I8: case ::HIR::CoreType::I16: @@ -837,16 +880,16 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa case ::HIR::CoreType::I64: case ::HIR::CoreType::Isize: { int64_t val = H::get_pattern_value_int(sp, pat, pe.val); - m_rules.push_back( PatternRule::make_Value( ::MIR::Constant(val) ) ); + this->push_rule( PatternRule::make_Value( ::MIR::Constant(val) ) ); } break; case ::HIR::CoreType::Bool: // TODO: Support values from `const` too - m_rules.push_back( PatternRule::make_Bool( pe.val.as_Integer().value != 0 ) ); + this->push_rule( PatternRule::make_Bool( pe.val.as_Integer().value != 0 ) ); break; case ::HIR::CoreType::Char: { // Char is just another name for 'u32'... but with a restricted range uint64_t val = H::get_pattern_value_int(sp, pat, pe.val); - m_rules.push_back( PatternRule::make_Value( ::MIR::Constant(val) ) ); + this->push_rule( PatternRule::make_Value( ::MIR::Constant(val) ) ); } break; case ::HIR::CoreType::Str: BUG(sp, "Hit match over `str` - must be `&str`"); @@ -856,18 +899,24 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa ) ), (Tuple, + m_field_path.push_back(0); TU_MATCH_DEF(::HIR::Pattern::Data, (pat.m_data), (pe), ( BUG(sp, "Matching tuple with invalid pattern - " << pat); ), (Any, - for(const auto& sty : e) + for(const auto& sty : e) { this->append_from(sp, pat, sty); + m_field_path.back() ++; + } ), (Tuple, assert(e.size() == pe.sub_patterns.size()); - for(unsigned int i = 0; i < e.size(); i ++) + for(unsigned int i = 0; i < e.size(); i ++) { this->append_from(sp, pe.sub_patterns[i], e[i]); + m_field_path.back() ++; + } ) ) + m_field_path.pop_back(); ), (Path, // This is either a struct destructure or an enum @@ -879,15 +928,11 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa TU_MATCH_DEF( ::HIR::Pattern::Data, (pat.m_data), (pe), ( BUG(sp, "Matching opaque type with invalid pattern - " << pat); ), (Any, - m_rules.push_back( PatternRule::make_Any({}) ); + this->push_rule( PatternRule::make_Any({}) ); ) ) ), (Struct, - //auto monomorph_cb = [&](const auto& ty)->const auto& { - // const auto& ge = ty.m_data.as_Generic(); - // if( ge. - // }; auto monomorph = [&](const auto& ty) { return monomorphise_type(sp, pbe->m_params, e.path.m_data.as_Generic().m_params, ty); }; const auto& str_data = pbe->m_data; TU_MATCHA( (str_data), (sd), @@ -896,6 +941,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa ( BUG(sp, "Match not allowed, " << ty << " with " << pat); ), (Any, // Nothing. + this->push_rule( PatternRule::make_Any({}) ); ), (Value, TODO(sp, "Match over struct - Unit + Value"); @@ -917,18 +963,18 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa ) ), (Named, - // TODO: Avoid needing to clone everything. - ::std::vector< ::HIR::TypeRef> types; - types.reserve( sd.size() ); - for( const auto& fld : sd ) { - types.push_back( monomorph(fld.second.ent) ); - } - TU_MATCH_DEF( ::HIR::Pattern::Data, (pat.m_data), (pe), ( BUG(sp, "Match not allowed, " << ty << " with " << pat); ), (Any, - for(const auto& sty : types) - this->append_from(sp, pat, sty); + m_field_path.push_back(0); + for(const auto& fld : sd) + { + ::HIR::TypeRef tmp; + const auto& sty_mono = (monomorphise_type_needed(fld.second.ent) ? tmp = monomorph(fld.second.ent) : fld.second.ent); + this->append_from(sp, pat, sty_mono); + m_field_path.back() ++; + } + m_field_path.pop_back(); ), (Struct, TODO(sp, "Match over struct - Named + Struct"); @@ -942,26 +988,32 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa TU_MATCH_DEF( ::HIR::Pattern::Data, (pat.m_data), (pe), ( BUG(sp, "Match not allowed, " << ty << " with " << pat); ), (Any, - m_rules.push_back( PatternRule::make_Any({}) ); + this->push_rule( PatternRule::make_Any({}) ); ), (EnumValue, - m_rules.push_back( PatternRule::make_Variant( {pe.binding_idx, {} } ) ); + this->push_rule( PatternRule::make_Variant( {pe.binding_idx, {} } ) ); ), (EnumTuple, const auto& var_def = pe.binding_ptr->m_variants.at(pe.binding_idx); const auto& fields_def = var_def.second.as_Tuple(); PatternRulesetBuilder sub_builder; + sub_builder.m_field_path.push_back(0); for( unsigned int i = 0; i < pe.sub_patterns.size(); i ++ ) { + sub_builder.m_field_path.back() = i; const auto& subpat = pe.sub_patterns[i]; - auto subty = monomorph(fields_def[i].ent); + const auto& ty_tpl = fields_def[i].ent; + + ::HIR::TypeRef tmp; + const auto& subty = (monomorphise_type_needed(ty_tpl) ? tmp = monomorph(ty_tpl) : ty_tpl); + sub_builder.append_from( sp, subpat, subty ); } - m_rules.push_back( PatternRule::make_Variant({ pe.binding_idx, mv$(sub_builder.m_rules) }) ); + this->push_rule( PatternRule::make_Variant({ pe.binding_idx, mv$(sub_builder.m_rules) }) ); ), (EnumTupleWildcard, - m_rules.push_back( PatternRule::make_Variant({ pe.binding_idx, {} }) ); + this->push_rule( PatternRule::make_Variant({ pe.binding_idx, {} }) ); ), (EnumStruct, const auto& var_def = pe.binding_ptr->m_variants.at(pe.binding_idx); @@ -979,8 +1031,11 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa } // 2. Iterate this list and recurse on the patterns PatternRulesetBuilder sub_builder; + sub_builder.m_field_path.push_back(0); for( unsigned int i = 0; i < tmp.size(); i ++ ) { + sub_builder.m_field_path.back() = i; + auto subty = monomorph(fields_def[i].second.ent); if( tmp[i] == ~0u ) { sub_builder.append_from( sp, ::HIR::Pattern(), subty ); @@ -990,7 +1045,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa sub_builder.append_from( sp, subpat, subty ); } } - m_rules.push_back( PatternRule::make_Variant({ pe.binding_idx, mv$(sub_builder.m_rules) }) ); + this->push_rule( PatternRule::make_Variant({ pe.binding_idx, mv$(sub_builder.m_rules) }) ); ) ) ) @@ -1001,7 +1056,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa TU_MATCH_DEF( ::HIR::Pattern::Data, (pat.m_data), (pe), ( BUG(sp, "Match not allowed, " << ty << " with " << pat); ), (Any, - m_rules.push_back( PatternRule::make_Any({}) ); + this->push_rule( PatternRule::make_Any({}) ); ) ) ), @@ -1024,6 +1079,7 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa } ), (Borrow, + m_field_path.push_back( FIELD_DEREF ); TU_MATCH_DEF( ::HIR::Pattern::Data, (pat.m_data), (pe), ( BUG(sp, "Matching borrow invalid pattern - " << pat); ), (Any, @@ -1035,13 +1091,14 @@ void PatternRulesetBuilder::append_from(const Span& sp, const ::HIR::Pattern& pa (Value, if( pe.val.is_String() ) { const auto& s = pe.val.as_String(); - m_rules.push_back( PatternRule::make_Value(s) ); + this->push_rule( PatternRule::make_Value(s) ); } else { BUG(sp, "Matching borrow invalid pattern - " << pat); } ) ) + m_field_path.pop_back(); ), (Pointer, if( pat.m_data.is_Any() ) { @@ -1117,7 +1174,7 @@ DecisionTreeNode::Values DecisionTreeNode::clone(const DecisionTreeNode::Values& throw ""; } DecisionTreeNode DecisionTreeNode::clone() const { - DecisionTreeNode rv; + DecisionTreeNode rv(m_field_path); rv.is_specialisation = is_specialisation; rv.m_branches = clone(m_branches); rv.m_default = clone(m_default); @@ -1128,6 +1185,13 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule assert( rule_count > 0 ); const auto& rule = *first_rule; + if( m_field_path.size() == 0 ) { + m_field_path = rule.field_path; + } + else { + assert( m_field_path == rule.field_path ); + } + TU_MATCHA( (rule), (e), (Any, { if( m_default.is_Unset() ) { @@ -1135,9 +1199,8 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule and_then(m_default); } else { - auto be = box$(DecisionTreeNode()); - be->populate_tree_from_rule(sp, first_rule+1, rule_count-1, and_then); - m_default = Branch(mv$(be)); + m_default = new_branch_subtree(rule.field_path); + m_default.as_Subtree()->populate_tree_from_rule(sp, first_rule+1, rule_count-1, and_then); } } else TU_IFLET( Branch, m_default, Subtree, be, @@ -1162,7 +1225,7 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule auto it = ::std::find_if( be.begin(), be.end(), [&](const auto& x){ return x.first >= e.idx; }); // If this variant isn't yet processed, add a new subtree for it if( it == be.end() || it->first != e.idx ) { - it = be.insert(it, ::std::make_pair(e.idx, Branch( box$(DecisionTreeNode()) ))); + it = be.insert(it, ::std::make_pair(e.idx, new_branch_subtree(rule.field_path))); assert( it->second.is_Subtree() ); } else { @@ -1178,7 +1241,7 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule { subtree.populate_tree_from_rule(sp, e.sub_rules.data(), e.sub_rules.size(), [&](auto& branch){ ASSERT_BUG(sp, branch.is_Unset(), "Duplicate terminator"); - branch = Branch( box$(DecisionTreeNode()) ); + branch = new_branch_subtree(rule.field_path); branch.as_Subtree()->populate_tree_from_rule(sp, first_rule+1, rule_count-1, and_then); }); } @@ -1206,7 +1269,7 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule auto& branch = (e ? be.true_branch : be.false_branch); if( branch.is_Unset() ) { - branch = Branch( box$( DecisionTreeNode() ) ); + branch = new_branch_subtree( rule.field_path ); } else if( branch.is_Terminal() ) { BUG(sp, "Duplicate terminal rule - " << branch.as_Terminal()); @@ -1239,7 +1302,7 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule auto& be = m_branches.as_Unsigned(); auto it = ::std::find_if(be.begin(), be.end(), [&](const auto& v){ return v.first.end >= ve; }); if( it == be.end() || it->first.start > ve ) { - it = be.insert( it, ::std::make_pair( Range<uint64_t> { ve,ve }, Branch( box$(DecisionTreeNode()) ) ) ); + it = be.insert( it, ::std::make_pair( Range<uint64_t> { ve,ve }, new_branch_subtree(rule.field_path) ) ); } else if( it->first.start == ve && it->first.end == ve ) { // Equal, continue and add sub-pat @@ -1280,7 +1343,7 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule auto it = ::std::find_if(be.begin(), be.end(), [&](const auto& v){ return v.first >= ve; }); if( it == be.end() || it->first != ve ) { - it = be.insert( it, ::std::make_pair(ve, Branch( box$(DecisionTreeNode()) ) ) ); + it = be.insert( it, ::std::make_pair(ve, new_branch_subtree(rule.field_path) ) ); } auto& branch = it->second; if( rule_count > 1 ) @@ -1317,7 +1380,7 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule auto it = ::std::find_if(be.begin(), be.end(), [&](const auto& v){ return v.first >= ve_start || v.first.contains(ve_end); }); // If the end of the list was reached, OR the located entry sorts after the end of this range if( it == be.end() || it->first >= ve_end ) { - it = be.insert( it, ::std::make_pair( Range<uint64_t> { ve_start,ve_end }, Branch( box$(DecisionTreeNode()) ) ) ); + it = be.insert( it, ::std::make_pair( Range<uint64_t> { ve_start,ve_end }, new_branch_subtree(rule.field_path) ) ); } else if( it->first.start == ve_start && it->first.end == ve_end ) { // Equal, add sub-pattern @@ -1334,21 +1397,21 @@ void DecisionTreeNode::populate_tree_from_rule(const Span& sp, const PatternRule if( ve_start == it->first.start ) { // Add single range after it ++; - it = be.insert(it, ::std::make_pair( Range<uint64_t> { ve_start + 1, ve_end }, Branch( box$(DecisionTreeNode()) ) ) ); + it = be.insert(it, ::std::make_pair( Range<uint64_t> { ve_start + 1, ve_end }, new_branch_subtree(rule.field_path) ) ); } else if( ve_end == it->first.start ) { // Add single range before - it = be.insert(it, ::std::make_pair( Range<uint64_t> { ve_start, ve_end - 1 }, Branch( box$(DecisionTreeNode()) ) ) ); + it = be.insert(it, ::std::make_pair( Range<uint64_t> { ve_start, ve_end - 1 }, new_branch_subtree(rule.field_path) ) ); } else { // Add two ranges auto end1 = it->first.start - 1; auto start2 = it->first.end + 1; - it = be.insert(it, ::std::make_pair( Range<uint64_t> { ve_start, end1 }, Branch( box$(DecisionTreeNode()) ) ) ); + it = be.insert(it, ::std::make_pair( Range<uint64_t> { ve_start, end1 }, new_branch_subtree(rule.field_path) ) ); auto& branch_1 = it->second; it ++; it ++; // Skip the original entry - it = be.insert(it, ::std::make_pair( Range<uint64_t> { start2, ve_end }, Branch( box$(DecisionTreeNode()) ) ) ); + it = be.insert(it, ::std::make_pair( Range<uint64_t> { start2, ve_end }, new_branch_subtree(rule.field_path) ) ); auto& branch_2 = it->second; if( rule_count > 1 ) @@ -1625,7 +1688,10 @@ void DecisionTreeNode::unify_from(const Branch& b) return os; } ::std::ostream& operator<<(::std::ostream& os, const DecisionTreeNode& x) { - os << "DTN { "; + os << "DTN ["; + for(const auto idx : x.m_field_path) + os << "." << static_cast<unsigned int>(idx); + os << "] { "; TU_MATCHA( (x.m_branches), (e), (Unset, os << "!, "; @@ -1634,11 +1700,13 @@ void DecisionTreeNode::unify_from(const Branch& b) os << "false = " << e.false_branch << ", true = " << e.true_branch << ", "; ), (Variant, + os << "V "; for(const auto& branch : e) { os << branch.first << " = " << branch.second << ", "; } ), (Unsigned, + os << "U "; for(const auto& branch : e) { const auto& range = branch.first; if( range.start == range.end ) { @@ -1651,6 +1719,7 @@ void DecisionTreeNode::unify_from(const Branch& b) } ), (Signed, + os << "S "; for(const auto& branch : e) { const auto& range = branch.first; if( range.start == range.end ) { @@ -1678,16 +1747,14 @@ void DecisionTreeNode::unify_from(const Branch& b) // ---------------------------- // DecisionTreeGen // ---------------------------- -void DecisionTreeGen::populate_tree_vals( +void DecisionTreeGen::generate_tree_code( const Span& sp, const DecisionTreeNode& node, - const ::HIR::TypeRef& ty, unsigned int ty_ofs, const ::MIR::LValue& val, + const ::HIR::TypeRef& ty, unsigned int ty_ofs, + const ::MIR::LValue& base_val, unsigned int depth, ::std::function<void(const DecisionTreeNode&)> and_then ) { - struct H { - }; - TRACE_FUNCTION_F("ty=" << ty << ", ty_ofs=" << ty_ofs << ", node=" << node); TU_MATCHA( (ty.m_data), (e), @@ -1717,7 +1784,7 @@ void DecisionTreeGen::populate_tree_vals( // Emit an if based on the route taken auto bb_false = m_builder.new_bb_unlinked(); auto bb_true = m_builder.new_bb_unlinked(); - m_builder.end_block( ::MIR::Terminator::make_If({ val.clone(), bb_true, bb_false }) ); + m_builder.end_block( ::MIR::Terminator::make_If({ node.get_field(base_val), bb_true, bb_false }) ); // Recurse into sub-patterns const auto& branch_false = ( !branches.false_branch.is_Unset() ? branches.false_branch : node.m_default ); @@ -1735,11 +1802,11 @@ void DecisionTreeGen::populate_tree_vals( case ::HIR::CoreType::U64: case ::HIR::CoreType::Usize: ASSERT_BUG(sp, node.m_branches.is_Unsigned(), "Tree for unsigned isn't a _Unsigned - node="<<node); - this->generate_branches_Unsigned(sp, node.m_default, node.m_branches.as_Unsigned(), ty, val, mv$(and_then)); + this->generate_branches_Unsigned(sp, node.m_default, node.m_branches.as_Unsigned(), ty, node.get_field(base_val), mv$(and_then)); break; case ::HIR::CoreType::Char: ASSERT_BUG(sp, node.m_branches.is_Unsigned(), "Tree for char isn't a _Unsigned - node="<<node); - this->generate_branches_Char(sp, node.m_default, node.m_branches.as_Unsigned(), ty, val, mv$(and_then)); + this->generate_branches_Char(sp, node.m_default, node.m_branches.as_Unsigned(), ty, node.get_field(base_val), mv$(and_then)); break; default: TODO(sp, "Primitive - " << ty); @@ -1748,15 +1815,21 @@ void DecisionTreeGen::populate_tree_vals( ), (Tuple, // Tuple - Recurse on each sub-type (increasing the index) + // - If complete, call completion callback if( ty_ofs == e.size() ) { and_then(node); } - else { - populate_tree_vals( sp, node, - e[ty_ofs], 0, ::MIR::LValue::make_Field({ box$(val.clone()), ty_ofs}), - [&](auto& n){ this->populate_tree_vals(sp, n, ty, ty_ofs+1, val, and_then); } + // - If the node is for this type, recurse + else if( node.m_field_path[depth] == ty_ofs ) { + generate_tree_code( sp, node, + e[ty_ofs], 0, base_val, depth+1, + [&](auto& n){ this->generate_tree_code(sp, n, ty, ty_ofs+1, base_val, depth, and_then); } ); } + // - Otherwise, go to the next node + else { + this->generate_tree_code(sp, node, ty, ty_ofs+1, base_val, depth, and_then); + } ), (Path, // This is either a struct destructure or an enum @@ -1768,12 +1841,51 @@ void DecisionTreeGen::populate_tree_vals( and_then(node); ), (Struct, - TODO(sp, "Match over struct - " << e.path); + auto monomorph = [&](const auto& ty) { return monomorphise_type(sp, pbe->m_params, e.path.m_data.as_Generic().m_params, ty); }; + TU_MATCHA( (pbe->m_data), (fields), + (Unit, + and_then(node); + ), + (Tuple, + if( ty_ofs == fields.size() ) { + and_then(node); + } + else if( node.m_field_path[depth] == ty_ofs ) { + const auto& fld = fields[ty_ofs]; + ::HIR::TypeRef tmp; + const auto& fld_ty = (monomorphise_type_needed(fld.ent) ? tmp = monomorph(fld.ent) : fld.ent); + generate_tree_code( sp, node, + fld_ty, 0, base_val, depth+1, + [&](auto& n){ this->generate_tree_code(sp, n, ty, ty_ofs+1, base_val, depth, and_then); } + ); + } + else { + this->generate_tree_code(sp, node, ty, ty_ofs+1, base_val, depth, and_then); + } + ), + (Named, + if( ty_ofs == fields.size() ) { + and_then(node); + } + else if( node.m_field_path[depth] == ty_ofs ) { + const auto& fld = fields[ty_ofs].second; + ::HIR::TypeRef tmp; + const auto& fld_ty = (monomorphise_type_needed(fld.ent) ? tmp = monomorph(fld.ent) : fld.ent); + generate_tree_code( sp, node, + fld_ty, 0, base_val, depth+1, + [&](auto& n){ this->generate_tree_code(sp, n, ty, ty_ofs+1, base_val, depth, and_then); } + ); + } + else { + this->generate_tree_code(sp, node, ty, ty_ofs+1, base_val, depth, and_then); + } + ) + ) ), (Enum, ASSERT_BUG(sp, node.m_branches.is_Variant(), "Tree for enum isn't a Variant - node="<<node); assert(pbe); - this->generate_branches_Enum(sp, node.m_default, node.m_branches.as_Variant(), ty, val, mv$(and_then)); + this->generate_branches_Enum(sp, node.m_default, node.m_branches.as_Variant(), ty, node.get_field(base_val), mv$(and_then)); ) ) ), @@ -1808,17 +1920,19 @@ void DecisionTreeGen::populate_tree_vals( assert( !branches.empty() ); for(const auto& branch : branches) { + auto have_val = node.get_field(base_val); + auto next_bb = (&branch == &branches.back() ? default_bb : m_builder.new_bb_unlinked()); auto test_val = m_builder.lvalue_or_temp( ::HIR::TypeRef(::HIR::CoreType::Str), ::MIR::Constant(branch.first) ); auto cmp_gt_bb = m_builder.new_bb_unlinked(); - auto lt_val = m_builder.lvalue_or_temp( ::HIR::CoreType::Bool, ::MIR::RValue::make_BinOp({ val.clone(), ::MIR::eBinOp::LT, test_val.clone() }) ); + auto lt_val = m_builder.lvalue_or_temp( ::HIR::CoreType::Bool, ::MIR::RValue::make_BinOp({ have_val.clone(), ::MIR::eBinOp::LT, test_val.clone() }) ); m_builder.end_block( ::MIR::Terminator::make_If({ mv$(lt_val), default_bb, cmp_gt_bb }) ); m_builder.set_cur_block(cmp_gt_bb); auto eq_bb = m_builder.new_bb_unlinked(); - auto gt_val = m_builder.lvalue_or_temp( ::HIR::CoreType::Bool, ::MIR::RValue::make_BinOp({ val.clone(), ::MIR::eBinOp::GT, test_val.clone() }) ); + auto gt_val = m_builder.lvalue_or_temp( ::HIR::CoreType::Bool, ::MIR::RValue::make_BinOp({ mv$(have_val), ::MIR::eBinOp::GT, test_val.clone() }) ); m_builder.end_block( ::MIR::Terminator::make_If({ mv$(gt_val), next_bb, eq_bb }) ); m_builder.set_cur_block(eq_bb); @@ -1832,10 +1946,9 @@ void DecisionTreeGen::populate_tree_vals( TODO(sp, "Match over &[T]"); } else { - populate_tree_vals( sp, node, - *e.inner, 0, ::MIR::LValue::make_Deref({ box$(val.clone()) }), - and_then - ); + // TODO: Should this check if the index is -1? The assertion doesn't fire in general use, so looks good. + ASSERT_BUG( sp, node.m_field_path[depth] == FIELD_DEREF, "& not matching on deref " << depth << " node=" <<node ); + generate_tree_code( sp, node, *e.inner, 0, base_val, depth+1, and_then ); } ), (Pointer, @@ -2011,7 +2124,7 @@ void DecisionTreeGen::generate_branches_Enum( { auto bb = variant_blocks[branch.first]; const auto& var = variants[branch.first]; - DEBUG(branch.first << " " << var.first << " = " << branch); + DEBUG(branch.first << " " << var.first << " = " << branch.second); m_builder.set_cur_block( bb ); this->generate_branch(branch.second, [&](auto& subnode) { TU_MATCHA( (var.second), (e), @@ -2029,7 +2142,7 @@ void DecisionTreeGen::generate_branches_Enum( ents.push_back( monomorphise_type(sp, enum_ref.m_params, enum_path.m_params, fld.ent) ); } ::HIR::TypeRef fake_ty { mv$(ents) }; - this->populate_tree_vals(sp, subnode, fake_ty, 0, ::MIR::LValue::make_Downcast({ box$(val.clone()), branch.first }), and_then); + this->generate_tree_code(sp, subnode, fake_ty, 0, ::MIR::LValue::make_Downcast({ box$(val.clone()), branch.first }), 0, and_then); ), (Struct, TODO(sp, "Enum pattern - struct"); @@ -2038,16 +2151,10 @@ void DecisionTreeGen::generate_branches_Enum( }); } - DEBUG("_"); - TU_MATCHA( (default_branch), (be), - (Unset, ), - (Terminal, - m_builder.set_cur_block(any_block); - m_builder.end_block( ::MIR::Terminator::make_Goto( this->get_block_for_rule( be ) ) ); - ), - (Subtree, + DEBUG("_ = " << default_branch); + if( !default_branch.is_Unset() ) + { m_builder.set_cur_block(any_block); - and_then( *be ); - ) - ) + this->generate_branch(default_branch, and_then); + } } |