diff options
author | John Hodge <tpg@ucc.asn.au> | 2017-09-14 22:45:35 +0800 |
---|---|---|
committer | John Hodge <tpg@ucc.asn.au> | 2017-09-14 22:45:56 +0800 |
commit | 57da2d2bd12033aca1c38bbfc1dcf8d6c60b174f (patch) | |
tree | eea5a685f01a49c2de59bca2c93fd564daf0e8b2 /src/expand | |
parent | 419451d2b4e4f5f93c3abeb11235108c7e6bdeb6 (diff) | |
download | mrust-57da2d2bd12033aca1c38bbfc1dcf8d6c60b174f.tar.gz |
Expand - Limited derive on unions (minimally tested, fixes #22)
Diffstat (limited to 'src/expand')
-rw-r--r-- | src/expand/derive.cpp | 112 |
1 files changed, 86 insertions, 26 deletions
diff --git a/src/expand/derive.cpp b/src/expand/derive.cpp index 1ab1a7e2..0a74ce67 100644 --- a/src/expand/derive.cpp +++ b/src/expand/derive.cpp @@ -64,8 +64,12 @@ static inline AST::ExprNodeP mk_exprnodep(AST::ExprNode* en){ return AST::ExprNo /// Interface for derive handlers struct Deriver { + virtual const char* trait_name() const = 0; virtual AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const = 0; virtual AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Enum& enm) const = 0; + virtual AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Union& unn) const { + ERROR(sp, E0000, "Cannot derive(" << trait_name() << ") on union"); + } AST::GenericParams get_params_with_bounds(const Span& sp, const AST::GenericParams& p, const AST::Path& trait_path, ::std::vector<TypeRef> additional_bounded_types) const @@ -145,6 +149,15 @@ struct Deriver return ret; } + ::std::vector<TypeRef> get_field_bounds(const AST::Union& unn) const + { + ::std::vector<TypeRef> ret; + for( const auto& fld : unn.m_variants ) + { + add_field_bound_from_ty(unn.params(), ret, fld.m_type); + } + return ret; + } void add_field_bound_from_ty(const AST::GenericParams& params, ::std::vector<TypeRef>& out_list, const TypeRef& ty) const { @@ -294,6 +307,8 @@ class Deriver_Debug: } public: + const char* trait_name() const override { return "Debug"; } + AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override { const ::std::string& name = type.path().nodes().back().name(); @@ -484,6 +499,8 @@ class Deriver_PartialEq: ); } public: + const char* trait_name() const override { return "PartialEq"; } + AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override { ::std::vector<AST::ExprNodeP> nodes; @@ -617,7 +634,6 @@ public: class Deriver_PartialOrd: public Deriver { - AST::Path get_path(const ::std::string core_name, ::std::string c1, ::std::string c2) const { return AST::Path(core_name, { AST::PathNode(c1, {}), AST::PathNode(c2, {}) }); @@ -691,6 +707,8 @@ class Deriver_PartialOrd: ); } public: + const char* trait_name() const override { return "PartialOrd"; } + AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override { ::std::vector<AST::ExprNodeP> nodes; @@ -905,6 +923,8 @@ class Deriver_Eq: } public: + const char* trait_name() const override { return "Eq"; } + AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override { const AST::Path assert_method_path = this->get_trait_path(core_name) + "assert_receiver_is_total_eq"; @@ -993,12 +1013,25 @@ public: mv$(arms) )); } + + AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Union& unn) const override + { + // Eq is just a marker, so it's valid to derive for union + const AST::Path assert_method_path = this->get_trait_path(core_name) + "assert_receiver_is_total_eq"; + ::std::vector<AST::ExprNodeP> nodes; + + for( const auto& fld : unn.m_variants ) + { + nodes.push_back( this->assert_is_eq(assert_method_path, this->field(fld.m_name)) ); + } + + return this->make_ret(sp, core_name, p, type, this->get_field_bounds(unn), NEWNODE(Block, mv$(nodes))); + } } g_derive_eq; class Deriver_Ord: public Deriver { - AST::Path get_path(const ::std::string core_name, ::std::string c1, ::std::string c2) const { return AST::Path(core_name, { AST::PathNode(c1, {}), AST::PathNode(c2, {}) }); @@ -1061,6 +1094,8 @@ class Deriver_Ord: return NEWNODE(NamedValue, this->get_path(core_name, "cmp", "Ordering", "Equal")); } public: + const char* trait_name() const override { return "Ord"; } + AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override { ::std::vector<AST::ExprNodeP> nodes; @@ -1278,6 +1313,8 @@ class Deriver_Clone: } public: + const char* trait_name() const override { return "Clone"; } + AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override { const AST::Path& ty_path = type.m_data.as_Path().path; @@ -1369,6 +1406,22 @@ public: mv$(arms) )); } + + AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Union& unn) const override + { + // Clone on a union can only be a bitwise copy. (TODO: This requires Copy) + auto ret = this->make_ret(sp, core_name, p, type, this->get_field_bounds(unn), NEWNODE(Deref, + NEWNODE(NamedValue, AST::Path("self")) + )); + + for(auto& b : ret.def().params().bounds()) + { + auto& be = b.as_IsTrait(); + be.trait = AST::Path(core_name, { AST::PathNode("marker", {}), AST::PathNode("Copy", {}) }); + } + + return ret; + } } g_derive_clone; class Deriver_Copy: @@ -1389,6 +1442,8 @@ class Deriver_Copy: } public: + const char* trait_name() const override { return "Copy"; } + AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override { return this->make_ret(sp, core_name, p, type, this->get_field_bounds(str), nullptr); @@ -1398,6 +1453,10 @@ public: { return this->make_ret(sp, core_name, p, type, this->get_field_bounds(enm), nullptr); } + AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Union& unn) const override + { + return this->make_ret(sp, core_name, p, type, this->get_field_bounds(unn), nullptr); + } } g_derive_copy; class Deriver_Default: @@ -1437,6 +1496,8 @@ class Deriver_Default: } public: + const char* trait_name() const override { return "Default"; } + AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override { const AST::Path& ty_path = type.m_data.as_Path().path; @@ -1528,6 +1589,8 @@ class Deriver_Hash: } public: + const char* trait_name() const override { return "Hash"; } + AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override { ::std::vector<AST::ExprNodeP> nodes; @@ -1687,6 +1750,8 @@ class Deriver_RustcEncodable: } public: + const char* trait_name() const override { return "RustcEncodable"; } + AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override { const ::std::string& struct_name = type.m_data.as_Path().path.nodes().back().name(); @@ -1928,6 +1993,8 @@ class Deriver_RustcDecodable: } public: + const char* trait_name() const override { return "RustcDecodable"; } + AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override { AST::Path base_path = type.m_data.as_Path().path; @@ -2099,30 +2166,20 @@ public: // -------------------------------------------------------------------- static const Deriver* find_impl(const ::std::string& trait_name) { - if( trait_name == "Debug" ) - return &g_derive_debug; - else if( trait_name == "PartialEq" ) - return &g_derive_partialeq; - else if( trait_name == "PartialOrd" ) - return &g_derive_partialord; - else if( trait_name == "Eq" ) - return &g_derive_eq; - else if( trait_name == "Ord" ) - return &g_derive_ord; - else if( trait_name == "Clone" ) - return &g_derive_clone; - else if( trait_name == "Copy" ) - return &g_derive_copy; - else if( trait_name == "Default" ) - return &g_derive_default; - else if( trait_name == "Hash" ) - return &g_derive_hash; - else if( trait_name == "RustcEncodable" ) - return &g_derive_rustc_encodable; - else if( trait_name == "RustcDecodable" ) - return &g_derive_rustc_decodable; - else - return nullptr; + #define _(obj) if(trait_name == obj.trait_name()) return &obj; + _(g_derive_debug) + _(g_derive_partialeq) + _(g_derive_partialord) + _(g_derive_eq) + _(g_derive_ord) + _(g_derive_clone) + _(g_derive_copy) + _(g_derive_default) + _(g_derive_hash) + _(g_derive_rustc_encodable) + _(g_derive_rustc_decodable) + #undef _ + return nullptr; } template<typename T> @@ -2177,6 +2234,9 @@ public: (None, // ), + (Union, + derive_item(sp, crate, mod, attr, path, e); + ), (Enum, derive_item(sp, crate, mod, attr, path, e); ), |