summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJohn Hodge (sonata) <tpg@mutabah.net>2015-01-21 20:30:20 +0800
committerJohn Hodge (sonata) <tpg@mutabah.net>2015-01-21 20:30:20 +0800
commit8d1acfa3993e64b0266365379602799350855f3f (patch)
tree8ebd1c2c02a322714d2ea64776ac504f955d5fd6 /src
parent15284f127f4c622bf4d67d8d8c44e1799f84e7cb (diff)
downloadmrust-8d1acfa3993e64b0266365379602799350855f3f.tar.gz
Type propagation coming along
Diffstat (limited to 'src')
-rw-r--r--src/ast/ast.cpp52
-rw-r--r--src/ast/ast.hpp8
-rw-r--r--src/ast/expr.cpp14
-rw-r--r--src/ast/expr.hpp18
-rw-r--r--src/ast/path.cpp2
-rw-r--r--src/ast/path.hpp3
-rw-r--r--src/common.hpp16
-rw-r--r--src/convert/ast_iterate.cpp8
-rw-r--r--src/convert/typecheck_expr.cpp191
-rw-r--r--src/convert/typecheck_params.cpp8
-rw-r--r--src/types.cpp84
-rw-r--r--src/types.hpp5
12 files changed, 394 insertions, 15 deletions
diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp
index 21a8b58a..7fb84100 100644
--- a/src/ast/ast.cpp
+++ b/src/ast/ast.cpp
@@ -104,6 +104,12 @@ const Module& Crate::get_root_module(const ::std::string& name) const {
return it->second.root_module();
throw ParseError::Generic("crate name unknown");
}
+
+Function& Crate::lookup_method(const TypeRef& type, const char *name)
+{
+ throw ParseError::Generic( FMT("TODO: Lookup method "<<name<<" for type " <<type));
+}
+
void Crate::load_extern_crate(::std::string name)
{
::std::ifstream is("output/"+name+".ast");
@@ -281,6 +287,42 @@ SERIALISE_TYPE(Enum::, "AST_Enum", {
s.item(m_variants);
})
+TypeRef Struct::get_field_type(const char *name, const ::std::vector<TypeRef>& args)
+{
+ if( args.size() != m_params.n_params() )
+ {
+ throw ::std::runtime_error("Incorrect parameter count for struct");
+ }
+ // TODO: Should the bounds be checked here? Or is the count sufficient?
+ for(const auto& f : m_fields)
+ {
+ if( f.first == name )
+ {
+ // Found it!
+ if( args.size() )
+ {
+ TypeRef res = f.second;
+ res.resolve_args( [&](const char *argname){
+ for(unsigned int i = 0; i < m_params.n_params(); i ++)
+ {
+ if( m_params.params()[i].name() == argname ) {
+ return args.at(i);
+ }
+ }
+ throw ::std::runtime_error("BUGCHECK - Unknown arg in field type");
+ });
+ return res;
+ }
+ else
+ {
+ return f.second;
+ }
+ }
+ }
+
+ throw ::std::runtime_error(FMT("No such field " << name));
+}
+
SERIALISE_TYPE(Struct::, "AST_Struct", {
s << m_params;
s << m_fields;
@@ -336,6 +378,16 @@ SERIALISE_TYPE_S(GenericBound, {
s.item(m_trait);
})
+int TypeParams::find_name(const char* name) const
+{
+ for( unsigned int i = 0; i < m_params.size(); i ++ )
+ {
+ if( m_params[i].name() == name )
+ return i;
+ }
+ return -1;
+}
+
::std::ostream& operator<<(::std::ostream& os, const TypeParams& tps)
{
//return os << "TypeParams({" << tps.m_params << "}, {" << tps.m_bounds << "})";
diff --git a/src/ast/ast.hpp b/src/ast/ast.hpp
index b4ccf582..b56bd26a 100644
--- a/src/ast/ast.hpp
+++ b/src/ast/ast.hpp
@@ -98,6 +98,8 @@ public:
m_bounds.push_back( ::std::move(bound) );
}
+ int find_name(const char* name) const;
+
friend ::std::ostream& operator<<(::std::ostream& os, const TypeParams& tp);
SERIALISABLE_PROTOTYPES();
};
@@ -243,6 +245,8 @@ public:
{
}
+ const Class fcn_class() const { return m_fcn_class; }
+
TypeParams& params() { return m_params; }
Expr& code() { return m_code; }
TypeRef& rettype() { return m_rettype; }
@@ -326,6 +330,8 @@ public:
TypeParams& params() { return m_params; }
::std::vector<StructItem>& fields() { return m_fields; }
+ TypeRef get_field_type(const char *name, const ::std::vector<TypeRef>& args);
+
SERIALISABLE_PROTOTYPES();
};
@@ -493,6 +499,8 @@ public:
::std::map< ::std::string, ExternCrate>& extern_crates() { return m_extern_crates; }
const ::std::map< ::std::string, ExternCrate>& extern_crates() const { return m_extern_crates; }
+ Function& lookup_method(const TypeRef& type, const char *name);
+
void load_extern_crate(::std::string name);
void iterate_functions( fcn_visitor_t* visitor );
diff --git a/src/ast/expr.cpp b/src/ast/expr.cpp
index 026a54c1..1009709b 100644
--- a/src/ast/expr.cpp
+++ b/src/ast/expr.cpp
@@ -44,6 +44,8 @@ SERIALISE_TYPE(Expr::, "Expr", {
else _(ExprNode_Tuple)
else _(ExprNode_NamedValue)
else _(ExprNode_Field)
+ else _(ExprNode_Deref)
+ else _(ExprNode_Cast)
else _(ExprNode_CallPath)
else _(ExprNode_BinOp)
else
@@ -178,6 +180,13 @@ SERIALISE_TYPE_S(ExprNode_Field, {
s.item(m_name);
})
+void ExprNode_Deref::visit(NodeVisitor& nv) {
+ nv.visit(*this);
+}
+SERIALISE_TYPE_S(ExprNode_Deref, {
+ s.item(m_value);
+});
+
void ExprNode_Cast::visit(NodeVisitor& nv) {
nv.visit(*this);
}
@@ -325,6 +334,11 @@ void NodeVisitor::visit(ExprNode_Field& node)
DEBUG("DEF - ExprNode_Field");
visit(node.m_obj);
}
+void NodeVisitor::visit(ExprNode_Deref& node)
+{
+ DEBUG("DEF - ExprNode_Deref");
+ visit(node.m_value);
+}
void NodeVisitor::visit(ExprNode_Cast& node)
{
DEBUG("DEF - ExprNode_Cast");
diff --git a/src/ast/expr.hpp b/src/ast/expr.hpp
index ad2f906d..f1d16af0 100644
--- a/src/ast/expr.hpp
+++ b/src/ast/expr.hpp
@@ -293,6 +293,23 @@ struct ExprNode_Field:
SERIALISABLE_PROTOTYPES();
};
+// Pointer dereference
+struct ExprNode_Deref:
+ public ExprNode
+{
+ ::std::unique_ptr<ExprNode> m_value;
+
+ ExprNode_Deref() {}
+ ExprNode_Deref(::std::unique_ptr<ExprNode> value):
+ m_value( ::std::move(value) )
+ {
+ }
+
+ virtual void visit(NodeVisitor& nv) override;
+
+ SERIALISABLE_PROTOTYPES();
+};
+
// Type cast ('as')
struct ExprNode_Cast:
public ExprNode
@@ -369,6 +386,7 @@ public:
virtual void visit(ExprNode_NamedValue& node);
virtual void visit(ExprNode_Field& node);
+ virtual void visit(ExprNode_Deref& node);
virtual void visit(ExprNode_Cast& node);
virtual void visit(ExprNode_BinOp& node);
};
diff --git a/src/ast/path.cpp b/src/ast/path.cpp
index 6d5e50b2..6daaf406 100644
--- a/src/ast/path.cpp
+++ b/src/ast/path.cpp
@@ -126,7 +126,7 @@ void Path::resolve(const Crate& root_crate)
// - Maybe leave that up to other code?
if( is_last ) {
m_binding_type = ALIAS;
- m_binding.alias = &it->data;
+ m_binding.alias_ = &it->data;
return ;
}
else {
diff --git a/src/ast/path.hpp b/src/ast/path.hpp
index 4576ceec..5325f8ef 100644
--- a/src/ast/path.hpp
+++ b/src/ast/path.hpp
@@ -89,7 +89,7 @@ private:
const Enum* enum_;
unsigned int idx;
} enumvar;
- const TypeAlias* alias;
+ const TypeAlias* alias_;
} m_binding;
public:
Path():
@@ -163,6 +163,7 @@ public:
//_(Enum, enum, ENUM)
_(Function, func, FUNCTION)
_(Static, static, STATIC)
+ _(TypeAlias, alias, ALIAS)
#undef _
const Enum& bound_enum() const {
assert(m_binding_type == ENUM || m_binding_type == ENUM_VAR); // Kinda evil, given that it has its own union entry
diff --git a/src/common.hpp b/src/common.hpp
index a6f0717c..b9b90270 100644
--- a/src/common.hpp
+++ b/src/common.hpp
@@ -84,6 +84,22 @@ option<T> None() {
namespace std {
template <typename T>
+inline ::std::ostream& operator<<(::std::ostream& os, const ::std::vector<T*>& v) {
+ if( v.size() > 0 )
+ {
+ bool is_first = true;
+ for( const auto& i : v )
+ {
+ if(!is_first)
+ os << ", ";
+ is_first = false;
+ os << *i;
+ }
+ }
+ return os;
+}
+
+template <typename T>
inline ::std::ostream& operator<<(::std::ostream& os, const ::std::vector<T>& v) {
if( v.size() > 0 )
{
diff --git a/src/convert/ast_iterate.cpp b/src/convert/ast_iterate.cpp
index 4ed63b56..1c232b33 100644
--- a/src/convert/ast_iterate.cpp
+++ b/src/convert/ast_iterate.cpp
@@ -84,7 +84,10 @@ void CASTIterator::handle_pattern(AST::Pattern& pat, const TypeRef& type_hint)
if( pat.binding().size() > 0 )
{
// TODO: Mutable bindings
- local_variable( false, pat.binding(), type_hint );
+ if(pat.binding() != "_")
+ {
+ local_variable( false, pat.binding(), type_hint );
+ }
}
for( auto& subpat : pat.sub_patterns() )
handle_pattern(subpat, (const TypeRef&)TypeRef());
@@ -205,6 +208,9 @@ void CASTIterator::handle_trait(AST::Path path, AST::Trait& trait)
{
start_scope();
handle_params( trait.params() );
+
+ local_type("Self", TypeRef(path));
+
for( auto& fcn : trait.functions() )
handle_function( path + fcn.name, fcn.data );
end_scope();
diff --git a/src/convert/typecheck_expr.cpp b/src/convert/typecheck_expr.cpp
index 44d3056b..21ad40b8 100644
--- a/src/convert/typecheck_expr.cpp
+++ b/src/convert/typecheck_expr.cpp
@@ -13,16 +13,20 @@ class CTypeChecker:
struct Scope {
::std::vector< ::std::tuple<bool, ::std::string, TypeRef> > vars;
+ ::std::vector< ::std::tuple< ::std::string, TypeRef> > types;
};
+ AST::Crate& m_crate;
::std::vector<Scope> m_scopes;
-protected:
- TypeRef& get_local(const char* name);
- void lookup_method(const TypeRef& type, const char* name);
public:
+ CTypeChecker(AST::Crate& crate):
+ m_crate(crate)
+ {}
+
virtual void start_scope() override;
virtual void local_variable(bool is_mut, ::std::string name, const TypeRef& type) override;
+ virtual void local_type(::std::string name, TypeRef type) override;
virtual void end_scope() override;
virtual void handle_function(AST::Path path, AST::Function& fcn) override;
@@ -30,6 +34,11 @@ public:
virtual void handle_enum(AST::Path path, AST::Enum& ) override {}
virtual void handle_struct(AST::Path path, AST::Struct& str) override {}
virtual void handle_alias(AST::Path path, AST::TypeAlias& ) override {}
+
+private:
+ TypeRef& get_local_var(const char* name);
+ const TypeRef& get_local_type(const char* name);
+ void lookup_method(const TypeRef& type, const char* name);
};
class CTC_NodeVisitor:
public AST::NodeVisitor
@@ -47,6 +56,9 @@ public:
virtual void visit(AST::ExprNode_Match& node) override;
+ virtual void visit(AST::ExprNode_Field& node) override;
+ virtual void visit(AST::ExprNode_Cast& node) override;
+
virtual void visit(AST::ExprNode_CallMethod& node) override;
virtual void visit(AST::ExprNode_CallPath& node) override;
};
@@ -57,14 +69,20 @@ void CTypeChecker::start_scope()
}
void CTypeChecker::local_variable(bool is_mut, ::std::string name, const TypeRef& type)
{
+ DEBUG("is_mut=" << is_mut << " name=" << name << " type=" << type);
m_scopes.back().vars.push_back( make_tuple(is_mut, name, TypeRef(type)) );
}
+void CTypeChecker::local_type(::std::string name, TypeRef type)
+{
+ DEBUG("name=" << name << " type=" << type);
+ m_scopes.back().types.push_back( make_tuple(name, ::std::move(type)) );
+}
void CTypeChecker::end_scope()
{
m_scopes.pop_back();
}
-TypeRef& CTypeChecker::get_local(const char* name)
+TypeRef& CTypeChecker::get_local_var(const char* name)
{
for( auto it = m_scopes.end(); it-- != m_scopes.begin(); )
{
@@ -76,7 +94,21 @@ TypeRef& CTypeChecker::get_local(const char* name)
}
}
}
- throw ::std::runtime_error(FMT("get_local - name " << name << " not found"));
+ throw ::std::runtime_error(FMT("get_local_type - name " << name << " not found"));
+}
+const TypeRef& CTypeChecker::get_local_type(const char* name)
+{
+ for( auto it = m_scopes.end(); it-- != m_scopes.begin(); )
+ {
+ for( auto it2 = it->types.end(); it2-- != it->types.begin(); )
+ {
+ if( name == ::std::get<0>(*it2) )
+ {
+ return ::std::get<1>(*it2);
+ }
+ }
+ }
+ throw ::std::runtime_error(FMT("get_local_type - name " << name << " not found"));
}
void CTypeChecker::handle_function(AST::Path path, AST::Function& fcn)
@@ -88,6 +120,21 @@ void CTypeChecker::handle_function(AST::Path path, AST::Function& fcn)
handle_type(fcn.rettype());
+ switch(fcn.fcn_class())
+ {
+ case AST::Function::CLASS_UNBOUND:
+ break;
+ case AST::Function::CLASS_REFMETHOD:
+ local_variable(false, "self", TypeRef(TypeRef::TagReference(), false, get_local_type("Self")));
+ break;
+ case AST::Function::CLASS_MUTMETHOD:
+ local_variable(false, "self", TypeRef(TypeRef::TagReference(), true, get_local_type("Self")));
+ break;
+ case AST::Function::CLASS_VALMETHOD:
+ local_variable(true, "self", TypeRef(get_local_type("Self")));
+ break;
+ }
+
for( auto& arg : fcn.args() )
{
handle_type(arg.second);
@@ -157,8 +204,9 @@ void CTC_NodeVisitor::visit(AST::ExprNode_NamedValue& node)
}
else
{
- TypeRef& local_type = m_tc.get_local( p[0].name().c_str() );
+ TypeRef& local_type = m_tc.get_local_var( p[0].name().c_str() );
node.get_res_type().merge_with( local_type );
+ DEBUG("res type = " << node.get_res_type());
local_type = node.get_res_type();
}
}
@@ -215,6 +263,62 @@ void CTC_NodeVisitor::visit(AST::ExprNode_Match& node)
}
}
+void CTC_NodeVisitor::visit(AST::ExprNode_Field& node)
+{
+ DEBUG("ExprNode_Field " << node.m_name);
+
+ AST::NodeVisitor::visit(node.m_obj);
+
+ TypeRef* tr = &node.m_obj->get_res_type();
+ DEBUG("ExprNode_Field - tr = " << *tr);
+ if( tr->is_concrete() )
+ {
+ // Must be a structure type (what about associated items?)
+ unsigned int deref_count = 0;
+ while( tr->is_reference() )
+ {
+ tr = &tr->sub_types()[0];
+ DEBUG("ExprNode_Field - ref deref to " << *tr);
+ deref_count ++;
+ }
+ if( !tr->is_path() )
+ {
+ throw ::std::runtime_error("ExprNode_Field - Type not a path");
+ }
+
+ // TODO Move this logic to types.cpp?
+ const AST::Path& p = tr->path();
+ switch( p.binding_type() )
+ {
+ case AST::Path::STRUCT: {
+ const AST::PathNode& lastnode = p.nodes().back();
+ AST::Struct& s = const_cast<AST::Struct&>( p.bound_struct() );
+ node.get_res_type().merge_with( s.get_field_type(node.m_name.c_str(), lastnode.args()) );
+ break; }
+ default:
+ throw ::std::runtime_error("TODO: Get field from non-structure");
+ }
+ DEBUG("deref_count = " << deref_count);
+ for( unsigned i = 0; i < deref_count; i ++ )
+ {
+ node.m_obj = ::std::unique_ptr<AST::ExprNode>(new AST::ExprNode_Deref( ::std::move(node.m_obj) ));
+ }
+ }
+ else
+ {
+ DEBUG("ExprNode_Field - Type not concrete, can't get field");
+ }
+}
+
+void CTC_NodeVisitor::visit(AST::ExprNode_Cast& node)
+{
+ DEBUG("ExprNode_Cast " << node.m_type);
+
+ AST::NodeVisitor::visit(node.m_value);
+
+ node.get_res_type().merge_with( node.m_type );
+}
+
void CTC_NodeVisitor::visit(AST::ExprNode_CallMethod& node)
{
DEBUG("ExprNode_CallMethod " << node.m_method);
@@ -228,32 +332,95 @@ void CTC_NodeVisitor::visit(AST::ExprNode_CallMethod& node)
// Locate method
const TypeRef& type = node.m_val->get_res_type();
- DEBUG("- type = " << type);
+ DEBUG("CallMethod - type = " << type);
if( type.is_wildcard() )
{
// No idea (yet)
// - TODO: Support case where a trait is known
+ throw ::std::runtime_error("Unknown type in CallMethod");
}
else
{
// - Search for a method on this type
- //const Function& fcn = type.lookup_method(node.m_method.name());
+ AST::Function& fcn = m_tc.m_crate.lookup_method(type, node.m_method.name().c_str());
+ if( fcn.params().n_params() != node.m_method.args().size() )
+ {
+ throw ::std::runtime_error("TODO: CallMethod with param count mismatch");
+ }
+ if( fcn.params().n_params() )
+ {
+ throw ::std::runtime_error("TODO: CallMethod with params");
+ }
+ node.get_res_type().merge_with( fcn.rettype() );
}
}
void CTC_NodeVisitor::visit(AST::ExprNode_CallPath& node)
{
DEBUG("ExprNode_CallPath - " << node.m_path);
+ ::std::vector<TypeRef> argtypes;
for( auto& arg : node.m_args )
{
AST::NodeVisitor::visit(arg);
+ argtypes.push_back( arg->get_res_type() );
}
- if(node.m_path.binding_type() == AST::Path::FUNCTION) {
+ if(node.m_path.binding_type() == AST::Path::FUNCTION)
+ {
const AST::Function& fcn = node.m_path.bound_func();
+
+ if( fcn.params().n_params() > 0 )
+ {
+ throw ::std::runtime_error("CallPath - TODO: Params on functions");
+ }
+
+ DEBUG("ExprNode_CallPath - rt = " << fcn.rettype());
+ node.get_res_type().merge_with( fcn.rettype() );
}
- else if(node.m_path.binding_type() == AST::Path::ENUM_VAR) {
- const AST::Enum& emn = node.m_path.bound_enum();
+ else if(node.m_path.binding_type() == AST::Path::ENUM_VAR)
+ {
+ const AST::Enum& enm = node.m_path.bound_enum();
+ const unsigned int idx = node.m_path.bound_idx();
+ const auto& var = enm.variants().at(idx);
+
+ const auto& params = enm.params();
// We know the enum, but it might have type params, need to handle that case
+
+ if( params.n_params() > 0 )
+ {
+ // 1. Obtain the pattern set from the path (should it be pre-marked with _ types?)
+ auto& path_args = node.m_path[node.m_path.size()-2].args();
+ while( path_args.size() < params.n_params() )
+ path_args.push_back( TypeRef() );
+ DEBUG("path_args = [" << path_args << "]");
+ // 2. Create a pattern from the argument types and the format of the variant
+ DEBUG("argtypes = [" << argtypes << "]");
+ ::std::vector<TypeRef> item_args(enm.params().n_params());
+ DEBUG("variant type = " << var.second << "");
+ var.second.match_args(
+ TypeRef(TypeRef::TagTuple(), argtypes),
+ [&](const char *name, const TypeRef& t) {
+ DEBUG("Binding " << name << " to type " << t);
+ int idx = params.find_name(name);
+ if( idx == -1 ) {
+ throw ::std::runtime_error(FMT("Can't find generic " << name));
+ }
+ item_args.at(idx).merge_with( t );
+ });
+ DEBUG("item_args = [" << item_args << "]");
+ // 3. Merge the two sets of arguments
+ for( unsigned int i = 0; i < path_args.size(); i ++ )
+ {
+ path_args[i].merge_with( item_args[i] );
+ }
+ DEBUG("new path_args = [" << path_args << "]");
+ }
+
+ AST::Path p = node.m_path;
+ p.nodes().pop_back();
+ TypeRef ty( ::std::move(p) );
+
+ DEBUG("ExprNode_CallPath - enum t = " << ty);
+ node.get_res_type().merge_with(ty);
}
else
{
@@ -264,7 +431,7 @@ void CTC_NodeVisitor::visit(AST::ExprNode_CallPath& node)
void Typecheck_Expr(AST::Crate& crate)
{
DEBUG(" >>>");
- CTypeChecker tc;
+ CTypeChecker tc(crate);
tc.handle_module(AST::Path({}), crate.root_module());
DEBUG(" <<<");
}
diff --git a/src/convert/typecheck_params.cpp b/src/convert/typecheck_params.cpp
index 8d757c8c..7a682ce7 100644
--- a/src/convert/typecheck_params.cpp
+++ b/src/convert/typecheck_params.cpp
@@ -1,4 +1,5 @@
/*
+/// Typecheck generic parameters (ensure that they match all generic bounds)
*/
#include <main_bindings.hpp>
#include "ast_iterate.hpp"
@@ -180,6 +181,8 @@ void CGenericParamChecker::handle_path(AST::Path& path, CASTIterator::PathMode p
const AST::TypeParams* params = nullptr;
switch(path.binding_type())
{
+ case AST::Path::UNBOUND:
+ throw ::std::runtime_error("CGenericParamChecker::handle_path - Unbound path");
case AST::Path::MODULE:
DEBUG("WTF - Module path, isn't this invalid at this stage?");
break;
@@ -192,6 +195,9 @@ void CGenericParamChecker::handle_path(AST::Path& path, CASTIterator::PathMode p
case AST::Path::ENUM:
params = &path.bound_enum().params();
if(0)
+ case AST::Path::ALIAS:
+ params = &path.bound_alias().params();
+ if(0)
case AST::Path::FUNCTION:
params = &path.bound_func().params();
@@ -203,6 +209,8 @@ void CGenericParamChecker::handle_path(AST::Path& path, CASTIterator::PathMode p
throw ::std::runtime_error( FMT("Checking '" << path << "', threw : " << e.what()) );
}
break;
+ default:
+ throw ::std::runtime_error("Unknown path type in CGenericParamChecker::handle_path");
}
}
diff --git a/src/types.cpp b/src/types.cpp
index 384c0e6e..7445dd03 100644
--- a/src/types.cpp
+++ b/src/types.cpp
@@ -82,6 +82,90 @@ void TypeRef::merge_with(const TypeRef& other)
}
}
+/// Resolve all Generic/Argument types to the value returned by the passed closure
+void TypeRef::resolve_args(::std::function<TypeRef(const char*)> fcn)
+{
+ switch(m_class)
+ {
+ case TypeRef::ANY:
+ // TODO: Is resolving args on an ANY an erorr?
+ break;
+ case TypeRef::UNIT:
+ case TypeRef::PRIMITIVE:
+ break;
+ case TypeRef::TUPLE:
+ case TypeRef::REFERENCE:
+ case TypeRef::POINTER:
+ case TypeRef::ARRAY:
+ for( auto& t : m_inner_types )
+ t.resolve_args(fcn);
+ case TypeRef::GENERIC:
+ *this = fcn(m_path[0].name().c_str());
+ break;
+ case TypeRef::PATH:
+ for(auto& n : m_path.nodes())
+ {
+ for(auto& p : n.args())
+ p.resolve_args(fcn);
+ }
+ break;
+ case TypeRef::ASSOCIATED:
+ for(auto& t : m_inner_types )
+ t.resolve_args(fcn);
+ break;
+ }
+}
+
+void TypeRef::match_args(const TypeRef& other, ::std::function<void(const char*,const TypeRef&)> fcn) const
+{
+ // If the other type is a wildcard, early return
+ // - TODO - Might want to restrict the other type to be of the same form as this type
+ if( other.m_class == TypeRef::ANY )
+ return;
+ // If this type is a generic, then call the closure with the other type
+ if( m_class == TypeRef::GENERIC ) {
+ fcn( m_path[0].name().c_str(), other );
+ return ;
+ }
+
+ // Any other case, it's a "pattern" match
+ if( m_class != other.m_class )
+ throw ::std::runtime_error("Type mismatch (class)");
+ switch(m_class)
+ {
+ case TypeRef::ANY:
+ // Wait, isn't this an error?
+ throw ::std::runtime_error("Encountered '_' in match_args");
+ case TypeRef::UNIT:
+ break;
+ case TypeRef::PRIMITIVE:
+ // TODO: Should check if the type matches
+ if( m_core_type != other.m_core_type )
+ throw ::std::runtime_error("Type mismatch (core)");
+ break;
+ case TypeRef::TUPLE:
+ if( m_inner_types.size() != other.m_inner_types.size() )
+ throw ::std::runtime_error("Type mismatch (tuple size)");
+ for(unsigned int i = 0; i < m_inner_types.size(); i ++ )
+ m_inner_types[i].match_args( other.m_inner_types[i], fcn );
+ break;
+ case TypeRef::REFERENCE:
+ case TypeRef::POINTER:
+ if( m_is_inner_mutable != other.m_is_inner_mutable )
+ throw ::std::runtime_error("Type mismatch (inner mutable)");
+ m_inner_types[0].match_args( other.m_inner_types[0], fcn );
+ break;
+ case TypeRef::ARRAY:
+ throw ::std::runtime_error("TODO: TypeRef::match_args on ARRAY");
+ case TypeRef::GENERIC:
+ throw ::std::runtime_error("Encountered GENERIC in match_args");
+ case TypeRef::PATH:
+ throw ::std::runtime_error("TODO: TypeRef::match_args on PATH");
+ case TypeRef::ASSOCIATED:
+ throw ::std::runtime_error("TODO: TypeRef::match_args on ASSOCIATED");
+ }
+}
+
bool TypeRef::is_concrete() const
{
switch(m_class)
diff --git a/src/types.hpp b/src/types.hpp
index 2c3a2056..27d74662 100644
--- a/src/types.hpp
+++ b/src/types.hpp
@@ -123,6 +123,10 @@ public:
/// Merge with another type (combines known aspects, conflitcs cause an exception)
void merge_with(const TypeRef& other);
+ /// Replace 'GENERIC' entries with the return value of the closure
+ void resolve_args(::std::function<TypeRef(const char*)> fcn);
+ /// Match 'GENERIC' entries with another type, passing matches to a closure
+ void match_args(const TypeRef& other, ::std::function<void(const char*,const TypeRef&)> fcn) const;
/// Returns true if the type is fully known (all sub-types are not wildcards)
bool is_concrete() const;
@@ -131,6 +135,7 @@ public:
bool is_unit() const { return m_class == UNIT; }
bool is_path() const { return m_class == PATH; }
bool is_type_param() const { return m_class == GENERIC; }
+ bool is_reference() const { return m_class == REFERENCE; }
const ::std::string& type_param() const { assert(is_type_param()); return m_path[0].name(); }
AST::Path& path() { assert(is_path() || m_class == ASSOCIATED); return m_path; }
::std::vector<TypeRef>& sub_types() { return m_inner_types; }