summaryrefslogtreecommitdiff
path: root/src/expand
diff options
context:
space:
mode:
authorJohn Hodge <tpg@ucc.asn.au>2017-09-14 22:45:35 +0800
committerJohn Hodge <tpg@ucc.asn.au>2017-09-14 22:45:56 +0800
commit57da2d2bd12033aca1c38bbfc1dcf8d6c60b174f (patch)
treeeea5a685f01a49c2de59bca2c93fd564daf0e8b2 /src/expand
parent419451d2b4e4f5f93c3abeb11235108c7e6bdeb6 (diff)
downloadmrust-57da2d2bd12033aca1c38bbfc1dcf8d6c60b174f.tar.gz
Expand - Limited derive on unions (minimally tested, fixes #22)
Diffstat (limited to 'src/expand')
-rw-r--r--src/expand/derive.cpp112
1 files changed, 86 insertions, 26 deletions
diff --git a/src/expand/derive.cpp b/src/expand/derive.cpp
index 1ab1a7e2..0a74ce67 100644
--- a/src/expand/derive.cpp
+++ b/src/expand/derive.cpp
@@ -64,8 +64,12 @@ static inline AST::ExprNodeP mk_exprnodep(AST::ExprNode* en){ return AST::ExprNo
/// Interface for derive handlers
struct Deriver
{
+ virtual const char* trait_name() const = 0;
virtual AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const = 0;
virtual AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Enum& enm) const = 0;
+ virtual AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Union& unn) const {
+ ERROR(sp, E0000, "Cannot derive(" << trait_name() << ") on union");
+ }
AST::GenericParams get_params_with_bounds(const Span& sp, const AST::GenericParams& p, const AST::Path& trait_path, ::std::vector<TypeRef> additional_bounded_types) const
@@ -145,6 +149,15 @@ struct Deriver
return ret;
}
+ ::std::vector<TypeRef> get_field_bounds(const AST::Union& unn) const
+ {
+ ::std::vector<TypeRef> ret;
+ for( const auto& fld : unn.m_variants )
+ {
+ add_field_bound_from_ty(unn.params(), ret, fld.m_type);
+ }
+ return ret;
+ }
void add_field_bound_from_ty(const AST::GenericParams& params, ::std::vector<TypeRef>& out_list, const TypeRef& ty) const
{
@@ -294,6 +307,8 @@ class Deriver_Debug:
}
public:
+ const char* trait_name() const override { return "Debug"; }
+
AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override
{
const ::std::string& name = type.path().nodes().back().name();
@@ -484,6 +499,8 @@ class Deriver_PartialEq:
);
}
public:
+ const char* trait_name() const override { return "PartialEq"; }
+
AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override
{
::std::vector<AST::ExprNodeP> nodes;
@@ -617,7 +634,6 @@ public:
class Deriver_PartialOrd:
public Deriver
{
-
AST::Path get_path(const ::std::string core_name, ::std::string c1, ::std::string c2) const
{
return AST::Path(core_name, { AST::PathNode(c1, {}), AST::PathNode(c2, {}) });
@@ -691,6 +707,8 @@ class Deriver_PartialOrd:
);
}
public:
+ const char* trait_name() const override { return "PartialOrd"; }
+
AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override
{
::std::vector<AST::ExprNodeP> nodes;
@@ -905,6 +923,8 @@ class Deriver_Eq:
}
public:
+ const char* trait_name() const override { return "Eq"; }
+
AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override
{
const AST::Path assert_method_path = this->get_trait_path(core_name) + "assert_receiver_is_total_eq";
@@ -993,12 +1013,25 @@ public:
mv$(arms)
));
}
+
+ AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Union& unn) const override
+ {
+ // Eq is just a marker, so it's valid to derive for union
+ const AST::Path assert_method_path = this->get_trait_path(core_name) + "assert_receiver_is_total_eq";
+ ::std::vector<AST::ExprNodeP> nodes;
+
+ for( const auto& fld : unn.m_variants )
+ {
+ nodes.push_back( this->assert_is_eq(assert_method_path, this->field(fld.m_name)) );
+ }
+
+ return this->make_ret(sp, core_name, p, type, this->get_field_bounds(unn), NEWNODE(Block, mv$(nodes)));
+ }
} g_derive_eq;
class Deriver_Ord:
public Deriver
{
-
AST::Path get_path(const ::std::string core_name, ::std::string c1, ::std::string c2) const
{
return AST::Path(core_name, { AST::PathNode(c1, {}), AST::PathNode(c2, {}) });
@@ -1061,6 +1094,8 @@ class Deriver_Ord:
return NEWNODE(NamedValue, this->get_path(core_name, "cmp", "Ordering", "Equal"));
}
public:
+ const char* trait_name() const override { return "Ord"; }
+
AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override
{
::std::vector<AST::ExprNodeP> nodes;
@@ -1278,6 +1313,8 @@ class Deriver_Clone:
}
public:
+ const char* trait_name() const override { return "Clone"; }
+
AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override
{
const AST::Path& ty_path = type.m_data.as_Path().path;
@@ -1369,6 +1406,22 @@ public:
mv$(arms)
));
}
+
+ AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Union& unn) const override
+ {
+ // Clone on a union can only be a bitwise copy. (TODO: This requires Copy)
+ auto ret = this->make_ret(sp, core_name, p, type, this->get_field_bounds(unn), NEWNODE(Deref,
+ NEWNODE(NamedValue, AST::Path("self"))
+ ));
+
+ for(auto& b : ret.def().params().bounds())
+ {
+ auto& be = b.as_IsTrait();
+ be.trait = AST::Path(core_name, { AST::PathNode("marker", {}), AST::PathNode("Copy", {}) });
+ }
+
+ return ret;
+ }
} g_derive_clone;
class Deriver_Copy:
@@ -1389,6 +1442,8 @@ class Deriver_Copy:
}
public:
+ const char* trait_name() const override { return "Copy"; }
+
AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override
{
return this->make_ret(sp, core_name, p, type, this->get_field_bounds(str), nullptr);
@@ -1398,6 +1453,10 @@ public:
{
return this->make_ret(sp, core_name, p, type, this->get_field_bounds(enm), nullptr);
}
+ AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Union& unn) const override
+ {
+ return this->make_ret(sp, core_name, p, type, this->get_field_bounds(unn), nullptr);
+ }
} g_derive_copy;
class Deriver_Default:
@@ -1437,6 +1496,8 @@ class Deriver_Default:
}
public:
+ const char* trait_name() const override { return "Default"; }
+
AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override
{
const AST::Path& ty_path = type.m_data.as_Path().path;
@@ -1528,6 +1589,8 @@ class Deriver_Hash:
}
public:
+ const char* trait_name() const override { return "Hash"; }
+
AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override
{
::std::vector<AST::ExprNodeP> nodes;
@@ -1687,6 +1750,8 @@ class Deriver_RustcEncodable:
}
public:
+ const char* trait_name() const override { return "RustcEncodable"; }
+
AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override
{
const ::std::string& struct_name = type.m_data.as_Path().path.nodes().back().name();
@@ -1928,6 +1993,8 @@ class Deriver_RustcDecodable:
}
public:
+ const char* trait_name() const override { return "RustcDecodable"; }
+
AST::Impl handle_item(Span sp, const ::std::string& core_name, const AST::GenericParams& p, const TypeRef& type, const AST::Struct& str) const override
{
AST::Path base_path = type.m_data.as_Path().path;
@@ -2099,30 +2166,20 @@ public:
// --------------------------------------------------------------------
static const Deriver* find_impl(const ::std::string& trait_name)
{
- if( trait_name == "Debug" )
- return &g_derive_debug;
- else if( trait_name == "PartialEq" )
- return &g_derive_partialeq;
- else if( trait_name == "PartialOrd" )
- return &g_derive_partialord;
- else if( trait_name == "Eq" )
- return &g_derive_eq;
- else if( trait_name == "Ord" )
- return &g_derive_ord;
- else if( trait_name == "Clone" )
- return &g_derive_clone;
- else if( trait_name == "Copy" )
- return &g_derive_copy;
- else if( trait_name == "Default" )
- return &g_derive_default;
- else if( trait_name == "Hash" )
- return &g_derive_hash;
- else if( trait_name == "RustcEncodable" )
- return &g_derive_rustc_encodable;
- else if( trait_name == "RustcDecodable" )
- return &g_derive_rustc_decodable;
- else
- return nullptr;
+ #define _(obj) if(trait_name == obj.trait_name()) return &obj;
+ _(g_derive_debug)
+ _(g_derive_partialeq)
+ _(g_derive_partialord)
+ _(g_derive_eq)
+ _(g_derive_ord)
+ _(g_derive_clone)
+ _(g_derive_copy)
+ _(g_derive_default)
+ _(g_derive_hash)
+ _(g_derive_rustc_encodable)
+ _(g_derive_rustc_decodable)
+ #undef _
+ return nullptr;
}
template<typename T>
@@ -2177,6 +2234,9 @@ public:
(None,
//
),
+ (Union,
+ derive_item(sp, crate, mod, attr, path, e);
+ ),
(Enum,
derive_item(sp, crate, mod, attr, path, e);
),