diff options
| -rw-r--r-- | samples/std.rs | 4 | ||||
| -rw-r--r-- | src/ast/ast.cpp | 52 | ||||
| -rw-r--r-- | src/ast/ast.hpp | 8 | ||||
| -rw-r--r-- | src/ast/expr.cpp | 14 | ||||
| -rw-r--r-- | src/ast/expr.hpp | 18 | ||||
| -rw-r--r-- | src/ast/path.cpp | 2 | ||||
| -rw-r--r-- | src/ast/path.hpp | 3 | ||||
| -rw-r--r-- | src/common.hpp | 16 | ||||
| -rw-r--r-- | src/convert/ast_iterate.cpp | 8 | ||||
| -rw-r--r-- | src/convert/typecheck_expr.cpp | 191 | ||||
| -rw-r--r-- | src/convert/typecheck_params.cpp | 8 | ||||
| -rw-r--r-- | src/types.cpp | 84 | ||||
| -rw-r--r-- | src/types.hpp | 5 | 
13 files changed, 396 insertions, 17 deletions
| diff --git a/samples/std.rs b/samples/std.rs index 4f152fc6..a0110c40 100644 --- a/samples/std.rs +++ b/samples/std.rs @@ -45,9 +45,9 @@ pub mod iter  pub mod char  { -    pub fn from_u32(v: u32) -> char +    pub fn from_u32(v: u32) -> ::option::Option<char>      { -        v   // TODO: This should generate a typecheck failure, but that part is incomplete +        ::option::Option::Some(v as char)          // Will eventually need a version of mem::transmute()      }  } diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index 21a8b58a..7fb84100 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -104,6 +104,12 @@ const Module& Crate::get_root_module(const ::std::string& name) const {          return it->second.root_module();
      throw ParseError::Generic("crate name unknown");
  }
 +
 +Function& Crate::lookup_method(const TypeRef& type, const char *name)
 +{
 +    throw ParseError::Generic( FMT("TODO: Lookup method "<<name<<" for type " <<type));
 +}
 +
  void Crate::load_extern_crate(::std::string name)
  {
      ::std::ifstream is("output/"+name+".ast");
 @@ -281,6 +287,42 @@ SERIALISE_TYPE(Enum::, "AST_Enum", {      s.item(m_variants);
  })
 +TypeRef Struct::get_field_type(const char *name, const ::std::vector<TypeRef>& args)
 +{
 +    if( args.size() != m_params.n_params() )
 +    {
 +        throw ::std::runtime_error("Incorrect parameter count for struct");
 +    }
 +    // TODO: Should the bounds be checked here? Or is the count sufficient?
 +    for(const auto& f : m_fields)
 +    {
 +        if( f.first == name )
 +        {
 +            // Found it!
 +            if( args.size() )
 +            {
 +                TypeRef res = f.second;
 +                res.resolve_args( [&](const char *argname){
 +                    for(unsigned int i = 0; i < m_params.n_params(); i ++)
 +                    {
 +                        if( m_params.params()[i].name() == argname ) {
 +                            return args.at(i);
 +                        }
 +                    }
 +                    throw ::std::runtime_error("BUGCHECK - Unknown arg in field type");
 +                    });
 +                return res;
 +            }
 +            else
 +            {
 +                return f.second;
 +            }
 +        }
 +    }
 +    
 +    throw ::std::runtime_error(FMT("No such field " << name));
 +}
 +
  SERIALISE_TYPE(Struct::, "AST_Struct", {
      s << m_params;
      s << m_fields;
 @@ -336,6 +378,16 @@ SERIALISE_TYPE_S(GenericBound, {      s.item(m_trait);
  })
 +int TypeParams::find_name(const char* name) const
 +{
 +    for( unsigned int i = 0; i < m_params.size(); i ++ )
 +    {
 +        if( m_params[i].name() == name )
 +            return i;
 +    }
 +    return -1;
 +}
 +
  ::std::ostream& operator<<(::std::ostream& os, const TypeParams& tps)
  {
      //return os << "TypeParams({" << tps.m_params << "}, {" << tps.m_bounds << "})";
 diff --git a/src/ast/ast.hpp b/src/ast/ast.hpp index b4ccf582..b56bd26a 100644 --- a/src/ast/ast.hpp +++ b/src/ast/ast.hpp @@ -98,6 +98,8 @@ public:          m_bounds.push_back( ::std::move(bound) );
      }
 +    int find_name(const char* name) const;
 +    
      friend ::std::ostream& operator<<(::std::ostream& os, const TypeParams& tp);
      SERIALISABLE_PROTOTYPES();
  };
 @@ -243,6 +245,8 @@ public:      {
      }
 +    const Class fcn_class() const { return m_fcn_class; }
 +    
      TypeParams& params() { return m_params; }
      Expr& code() { return m_code; }
      TypeRef& rettype() { return m_rettype; }
 @@ -326,6 +330,8 @@ public:      TypeParams& params() { return m_params; }
      ::std::vector<StructItem>& fields() { return m_fields; }
 +    TypeRef get_field_type(const char *name, const ::std::vector<TypeRef>& args);
 +    
      SERIALISABLE_PROTOTYPES();
  };
 @@ -493,6 +499,8 @@ public:      ::std::map< ::std::string, ExternCrate>& extern_crates() { return m_extern_crates; }   
      const ::std::map< ::std::string, ExternCrate>& extern_crates() const { return m_extern_crates; }   
 +    Function& lookup_method(const TypeRef& type, const char *name);
 +    
      void load_extern_crate(::std::string name);
      void iterate_functions( fcn_visitor_t* visitor );
 diff --git a/src/ast/expr.cpp b/src/ast/expr.cpp index 026a54c1..1009709b 100644 --- a/src/ast/expr.cpp +++ b/src/ast/expr.cpp @@ -44,6 +44,8 @@ SERIALISE_TYPE(Expr::, "Expr", {      else _(ExprNode_Tuple)      else _(ExprNode_NamedValue)      else _(ExprNode_Field) +    else _(ExprNode_Deref) +    else _(ExprNode_Cast)      else _(ExprNode_CallPath)      else _(ExprNode_BinOp)      else @@ -178,6 +180,13 @@ SERIALISE_TYPE_S(ExprNode_Field, {      s.item(m_name);  }) +void ExprNode_Deref::visit(NodeVisitor& nv) { +    nv.visit(*this); +} +SERIALISE_TYPE_S(ExprNode_Deref, { +    s.item(m_value); +}); +  void ExprNode_Cast::visit(NodeVisitor& nv) {      nv.visit(*this);  } @@ -325,6 +334,11 @@ void NodeVisitor::visit(ExprNode_Field& node)      DEBUG("DEF - ExprNode_Field");      visit(node.m_obj);  } +void NodeVisitor::visit(ExprNode_Deref& node)  +{ +    DEBUG("DEF - ExprNode_Deref"); +    visit(node.m_value); +}  void NodeVisitor::visit(ExprNode_Cast& node)   {      DEBUG("DEF - ExprNode_Cast"); diff --git a/src/ast/expr.hpp b/src/ast/expr.hpp index ad2f906d..f1d16af0 100644 --- a/src/ast/expr.hpp +++ b/src/ast/expr.hpp @@ -293,6 +293,23 @@ struct ExprNode_Field:      SERIALISABLE_PROTOTYPES();  }; +// Pointer dereference +struct ExprNode_Deref: +    public ExprNode +{ +    ::std::unique_ptr<ExprNode>    m_value; +     +    ExprNode_Deref() {} +    ExprNode_Deref(::std::unique_ptr<ExprNode> value): +        m_value( ::std::move(value) ) +    { +    } +     +    virtual void visit(NodeVisitor& nv) override; + +    SERIALISABLE_PROTOTYPES(); +}; +  // Type cast ('as')  struct ExprNode_Cast:      public ExprNode @@ -369,6 +386,7 @@ public:      virtual void visit(ExprNode_NamedValue& node);      virtual void visit(ExprNode_Field& node); +    virtual void visit(ExprNode_Deref& node);      virtual void visit(ExprNode_Cast& node);      virtual void visit(ExprNode_BinOp& node);  }; diff --git a/src/ast/path.cpp b/src/ast/path.cpp index 6d5e50b2..6daaf406 100644 --- a/src/ast/path.cpp +++ b/src/ast/path.cpp @@ -126,7 +126,7 @@ void Path::resolve(const Crate& root_crate)                  // - Maybe leave that up to other code?                  if( is_last ) {                      m_binding_type = ALIAS; -                    m_binding.alias = &it->data; +                    m_binding.alias_ = &it->data;                      return ;                  }                  else { diff --git a/src/ast/path.hpp b/src/ast/path.hpp index 4576ceec..5325f8ef 100644 --- a/src/ast/path.hpp +++ b/src/ast/path.hpp @@ -89,7 +89,7 @@ private:              const Enum* enum_;              unsigned int idx;          } enumvar; -        const TypeAlias*    alias; +        const TypeAlias*    alias_;      } m_binding;  public:      Path(): @@ -163,6 +163,7 @@ public:      //_(Enum,   enum,   ENUM)      _(Function, func, FUNCTION)      _(Static, static, STATIC) +    _(TypeAlias, alias, ALIAS)      #undef _      const Enum& bound_enum() const {          assert(m_binding_type == ENUM || m_binding_type == ENUM_VAR);  // Kinda evil, given that it has its own union entry diff --git a/src/common.hpp b/src/common.hpp index a6f0717c..b9b90270 100644 --- a/src/common.hpp +++ b/src/common.hpp @@ -84,6 +84,22 @@ option<T> None() {  namespace std {  template <typename T> +inline ::std::ostream& operator<<(::std::ostream& os, const ::std::vector<T*>& v) { +    if( v.size() > 0 ) +    { +        bool is_first = true; +        for( const auto& i : v ) +        { +            if(!is_first) +                os << ", "; +            is_first = false; +            os << *i; +        } +    } +    return os; +} + +template <typename T>  inline ::std::ostream& operator<<(::std::ostream& os, const ::std::vector<T>& v) {      if( v.size() > 0 )      { diff --git a/src/convert/ast_iterate.cpp b/src/convert/ast_iterate.cpp index 4ed63b56..1c232b33 100644 --- a/src/convert/ast_iterate.cpp +++ b/src/convert/ast_iterate.cpp @@ -84,7 +84,10 @@ void CASTIterator::handle_pattern(AST::Pattern& pat, const TypeRef& type_hint)      if( pat.binding().size() > 0 )      {          // TODO: Mutable bindings -        local_variable( false, pat.binding(), type_hint ); +        if(pat.binding() != "_") +        { +            local_variable( false, pat.binding(), type_hint ); +        }      }      for( auto& subpat : pat.sub_patterns() )          handle_pattern(subpat, (const TypeRef&)TypeRef()); @@ -205,6 +208,9 @@ void CASTIterator::handle_trait(AST::Path path, AST::Trait& trait)  {      start_scope();      handle_params( trait.params() ); +     +    local_type("Self", TypeRef(path)); +          for( auto& fcn : trait.functions() )          handle_function( path + fcn.name, fcn.data );      end_scope(); diff --git a/src/convert/typecheck_expr.cpp b/src/convert/typecheck_expr.cpp index 44d3056b..21ad40b8 100644 --- a/src/convert/typecheck_expr.cpp +++ b/src/convert/typecheck_expr.cpp @@ -13,16 +13,20 @@ class CTypeChecker:      struct Scope {          ::std::vector< ::std::tuple<bool, ::std::string, TypeRef> >   vars; +        ::std::vector< ::std::tuple< ::std::string, TypeRef> >  types;      }; +    AST::Crate& m_crate;      ::std::vector<Scope>    m_scopes; -protected: -    TypeRef& get_local(const char* name); -    void lookup_method(const TypeRef& type, const char* name);  public: +    CTypeChecker(AST::Crate& crate): +        m_crate(crate) +    {} +          virtual void start_scope() override;      virtual void local_variable(bool is_mut, ::std::string name, const TypeRef& type) override; +    virtual void local_type(::std::string name, TypeRef type) override;      virtual void end_scope() override;      virtual void handle_function(AST::Path path, AST::Function& fcn) override; @@ -30,6 +34,11 @@ public:      virtual void handle_enum(AST::Path path, AST::Enum& ) override {}      virtual void handle_struct(AST::Path path, AST::Struct& str) override {}      virtual void handle_alias(AST::Path path, AST::TypeAlias& ) override {} + +private: +    TypeRef& get_local_var(const char* name); +    const TypeRef& get_local_type(const char* name); +    void lookup_method(const TypeRef& type, const char* name);  };  class CTC_NodeVisitor:      public AST::NodeVisitor @@ -47,6 +56,9 @@ public:      virtual void visit(AST::ExprNode_Match& node) override; +    virtual void visit(AST::ExprNode_Field& node) override; +    virtual void visit(AST::ExprNode_Cast& node) override; +          virtual void visit(AST::ExprNode_CallMethod& node) override;      virtual void visit(AST::ExprNode_CallPath& node) override;  }; @@ -57,14 +69,20 @@ void CTypeChecker::start_scope()  }  void CTypeChecker::local_variable(bool is_mut, ::std::string name, const TypeRef& type)   { +    DEBUG("is_mut=" << is_mut << " name=" << name << " type=" << type);      m_scopes.back().vars.push_back( make_tuple(is_mut, name, TypeRef(type)) );  } +void CTypeChecker::local_type(::std::string name, TypeRef type) +{ +    DEBUG("name=" << name << " type=" << type); +    m_scopes.back().types.push_back( make_tuple(name, ::std::move(type)) ); +}  void CTypeChecker::end_scope()   {      m_scopes.pop_back();  } -TypeRef& CTypeChecker::get_local(const char* name) +TypeRef& CTypeChecker::get_local_var(const char* name)  {      for( auto it = m_scopes.end(); it-- != m_scopes.begin(); )      { @@ -76,7 +94,21 @@ TypeRef& CTypeChecker::get_local(const char* name)              }          }      } -    throw ::std::runtime_error(FMT("get_local - name " << name << " not found")); +    throw ::std::runtime_error(FMT("get_local_type - name " << name << " not found")); +} +const TypeRef& CTypeChecker::get_local_type(const char* name) +{ +    for( auto it = m_scopes.end(); it-- != m_scopes.begin(); ) +    { +        for( auto it2 = it->types.end(); it2-- != it->types.begin(); ) +        { +            if( name == ::std::get<0>(*it2) ) +            { +                return ::std::get<1>(*it2); +            } +        } +    } +    throw ::std::runtime_error(FMT("get_local_type - name " << name << " not found"));  }  void CTypeChecker::handle_function(AST::Path path, AST::Function& fcn) @@ -88,6 +120,21 @@ void CTypeChecker::handle_function(AST::Path path, AST::Function& fcn)      handle_type(fcn.rettype()); +    switch(fcn.fcn_class()) +    { +    case AST::Function::CLASS_UNBOUND: +        break; +    case AST::Function::CLASS_REFMETHOD: +        local_variable(false, "self", TypeRef(TypeRef::TagReference(), false, get_local_type("Self"))); +        break; +    case AST::Function::CLASS_MUTMETHOD: +        local_variable(false, "self", TypeRef(TypeRef::TagReference(), true, get_local_type("Self"))); +        break; +    case AST::Function::CLASS_VALMETHOD: +        local_variable(true, "self", TypeRef(get_local_type("Self"))); +        break; +    } +          for( auto& arg : fcn.args() )      {          handle_type(arg.second); @@ -157,8 +204,9 @@ void CTC_NodeVisitor::visit(AST::ExprNode_NamedValue& node)      }      else      { -        TypeRef& local_type = m_tc.get_local( p[0].name().c_str() ); +        TypeRef& local_type = m_tc.get_local_var( p[0].name().c_str() );          node.get_res_type().merge_with( local_type ); +        DEBUG("res type = " << node.get_res_type());          local_type = node.get_res_type();      }  } @@ -215,6 +263,62 @@ void CTC_NodeVisitor::visit(AST::ExprNode_Match& node)      }  } +void CTC_NodeVisitor::visit(AST::ExprNode_Field& node) +{ +    DEBUG("ExprNode_Field " << node.m_name); +     +    AST::NodeVisitor::visit(node.m_obj); +     +    TypeRef* tr = &node.m_obj->get_res_type(); +    DEBUG("ExprNode_Field - tr = " << *tr); +    if( tr->is_concrete() ) +    { +        // Must be a structure type (what about associated items?) +        unsigned int deref_count = 0; +        while( tr->is_reference() ) +        { +            tr = &tr->sub_types()[0]; +            DEBUG("ExprNode_Field - ref deref to " << *tr); +            deref_count ++; +        } +        if( !tr->is_path() ) +        { +            throw ::std::runtime_error("ExprNode_Field - Type not a path"); +        } +         +        // TODO Move this logic to types.cpp? +        const AST::Path& p = tr->path(); +        switch( p.binding_type() ) +        { +        case AST::Path::STRUCT: { +            const AST::PathNode& lastnode = p.nodes().back(); +            AST::Struct& s = const_cast<AST::Struct&>( p.bound_struct() ); +            node.get_res_type().merge_with( s.get_field_type(node.m_name.c_str(), lastnode.args()) ); +            break; } +        default: +            throw ::std::runtime_error("TODO: Get field from non-structure"); +        } +        DEBUG("deref_count = " << deref_count); +        for( unsigned i = 0; i < deref_count; i ++ ) +        { +            node.m_obj = ::std::unique_ptr<AST::ExprNode>(new AST::ExprNode_Deref( ::std::move(node.m_obj) )); +        } +    } +    else +    { +        DEBUG("ExprNode_Field - Type not concrete, can't get field"); +    } +} + +void CTC_NodeVisitor::visit(AST::ExprNode_Cast& node) +{ +    DEBUG("ExprNode_Cast " << node.m_type); +     +    AST::NodeVisitor::visit(node.m_value); + +    node.get_res_type().merge_with( node.m_type ); +} +  void CTC_NodeVisitor::visit(AST::ExprNode_CallMethod& node)  {      DEBUG("ExprNode_CallMethod " << node.m_method); @@ -228,32 +332,95 @@ void CTC_NodeVisitor::visit(AST::ExprNode_CallMethod& node)      // Locate method      const TypeRef& type = node.m_val->get_res_type(); -    DEBUG("- type = " << type); +    DEBUG("CallMethod - type = " << type);      if( type.is_wildcard() )      {          // No idea (yet)          // - TODO: Support case where a trait is known +        throw ::std::runtime_error("Unknown type in CallMethod");      }      else      {          // - Search for a method on this type -        //const Function& fcn = type.lookup_method(node.m_method.name()); +        AST::Function& fcn = m_tc.m_crate.lookup_method(type, node.m_method.name().c_str()); +        if( fcn.params().n_params() != node.m_method.args().size() ) +        { +            throw ::std::runtime_error("TODO: CallMethod with param count mismatch"); +        } +        if( fcn.params().n_params() ) +        { +            throw ::std::runtime_error("TODO: CallMethod with params"); +        } +        node.get_res_type().merge_with( fcn.rettype() );      }  }  void CTC_NodeVisitor::visit(AST::ExprNode_CallPath& node)  {      DEBUG("ExprNode_CallPath - " << node.m_path); +    ::std::vector<TypeRef> argtypes;      for( auto& arg : node.m_args )      {          AST::NodeVisitor::visit(arg); +        argtypes.push_back( arg->get_res_type() );      } -    if(node.m_path.binding_type() == AST::Path::FUNCTION) { +    if(node.m_path.binding_type() == AST::Path::FUNCTION) +    {          const AST::Function& fcn = node.m_path.bound_func(); +         +        if( fcn.params().n_params() > 0 ) +        { +            throw ::std::runtime_error("CallPath - TODO: Params on functions"); +        } +         +        DEBUG("ExprNode_CallPath - rt = " << fcn.rettype()); +        node.get_res_type().merge_with( fcn.rettype() );      } -    else if(node.m_path.binding_type() == AST::Path::ENUM_VAR) { -        const AST::Enum& emn = node.m_path.bound_enum(); +    else if(node.m_path.binding_type() == AST::Path::ENUM_VAR) +    { +        const AST::Enum& enm = node.m_path.bound_enum(); +        const unsigned int idx = node.m_path.bound_idx(); +        const auto& var = enm.variants().at(idx); +         +        const auto& params = enm.params();          // We know the enum, but it might have type params, need to handle that case +         +        if( params.n_params() > 0 ) +        { +            // 1. Obtain the pattern set from the path (should it be pre-marked with _ types?) +            auto& path_args = node.m_path[node.m_path.size()-2].args(); +            while( path_args.size() < params.n_params() ) +                path_args.push_back( TypeRef() ); +            DEBUG("path_args = [" << path_args << "]"); +            // 2. Create a pattern from the argument types and the format of the variant +            DEBUG("argtypes = [" << argtypes << "]"); +            ::std::vector<TypeRef>  item_args(enm.params().n_params()); +            DEBUG("variant type = " << var.second << ""); +            var.second.match_args( +                TypeRef(TypeRef::TagTuple(), argtypes), +                [&](const char *name, const TypeRef& t) { +                    DEBUG("Binding " << name << " to type " << t); +                    int idx = params.find_name(name); +                    if( idx == -1 ) { +                        throw ::std::runtime_error(FMT("Can't find generic " << name)); +                    } +                    item_args.at(idx).merge_with( t ); +                }); +            DEBUG("item_args = [" << item_args << "]"); +            // 3. Merge the two sets of arguments +            for( unsigned int i = 0; i < path_args.size(); i ++ ) +            { +                path_args[i].merge_with( item_args[i] ); +            } +            DEBUG("new path_args = [" << path_args << "]"); +        } +     +        AST::Path   p = node.m_path; +        p.nodes().pop_back(); +        TypeRef ty( ::std::move(p) ); +         +        DEBUG("ExprNode_CallPath - enum t = " << ty); +        node.get_res_type().merge_with(ty);      }      else       { @@ -264,7 +431,7 @@ void CTC_NodeVisitor::visit(AST::ExprNode_CallPath& node)  void Typecheck_Expr(AST::Crate& crate)  {      DEBUG(" >>>"); -    CTypeChecker    tc; +    CTypeChecker    tc(crate);      tc.handle_module(AST::Path({}), crate.root_module());      DEBUG(" <<<");  } diff --git a/src/convert/typecheck_params.cpp b/src/convert/typecheck_params.cpp index 8d757c8c..7a682ce7 100644 --- a/src/convert/typecheck_params.cpp +++ b/src/convert/typecheck_params.cpp @@ -1,4 +1,5 @@  /* +/// Typecheck generic parameters (ensure that they match all generic bounds)   */  #include <main_bindings.hpp>  #include "ast_iterate.hpp" @@ -180,6 +181,8 @@ void CGenericParamChecker::handle_path(AST::Path& path, CASTIterator::PathMode p      const AST::TypeParams* params = nullptr;      switch(path.binding_type())      { +    case AST::Path::UNBOUND: +        throw ::std::runtime_error("CGenericParamChecker::handle_path - Unbound path");      case AST::Path::MODULE:          DEBUG("WTF - Module path, isn't this invalid at this stage?");          break; @@ -192,6 +195,9 @@ void CGenericParamChecker::handle_path(AST::Path& path, CASTIterator::PathMode p      case AST::Path::ENUM:          params = &path.bound_enum().params();          if(0) +    case AST::Path::ALIAS: +        params = &path.bound_alias().params(); +        if(0)      case AST::Path::FUNCTION:          params = &path.bound_func().params(); @@ -203,6 +209,8 @@ void CGenericParamChecker::handle_path(AST::Path& path, CASTIterator::PathMode p              throw ::std::runtime_error( FMT("Checking '" << path << "', threw : " << e.what()) );          }          break; +    default: +        throw ::std::runtime_error("Unknown path type in CGenericParamChecker::handle_path");      }  } diff --git a/src/types.cpp b/src/types.cpp index 384c0e6e..7445dd03 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -82,6 +82,90 @@ void TypeRef::merge_with(const TypeRef& other)      }  } +/// Resolve all Generic/Argument types to the value returned by the passed closure +void TypeRef::resolve_args(::std::function<TypeRef(const char*)> fcn) +{ +    switch(m_class) +    { +    case TypeRef::ANY: +        // TODO: Is resolving args on an ANY an erorr? +        break; +    case TypeRef::UNIT: +    case TypeRef::PRIMITIVE: +        break; +    case TypeRef::TUPLE: +    case TypeRef::REFERENCE: +    case TypeRef::POINTER: +    case TypeRef::ARRAY: +        for( auto& t : m_inner_types ) +            t.resolve_args(fcn); +    case TypeRef::GENERIC: +        *this = fcn(m_path[0].name().c_str()); +        break; +    case TypeRef::PATH: +        for(auto& n : m_path.nodes()) +        { +            for(auto& p : n.args()) +                p.resolve_args(fcn); +        } +        break; +    case TypeRef::ASSOCIATED: +        for(auto& t : m_inner_types ) +            t.resolve_args(fcn); +        break; +    } +} + +void TypeRef::match_args(const TypeRef& other, ::std::function<void(const char*,const TypeRef&)> fcn) const +{ +    // If the other type is a wildcard, early return +    // - TODO - Might want to restrict the other type to be of the same form as this type +    if( other.m_class == TypeRef::ANY ) +        return; +    // If this type is a generic, then call the closure with the other type +    if( m_class == TypeRef::GENERIC ) { +        fcn( m_path[0].name().c_str(), other ); +        return ; +    } +     +    // Any other case, it's a "pattern" match +    if( m_class != other.m_class ) +        throw ::std::runtime_error("Type mismatch (class)"); +    switch(m_class) +    { +    case TypeRef::ANY: +        // Wait, isn't this an error? +        throw ::std::runtime_error("Encountered '_' in match_args"); +    case TypeRef::UNIT: +        break; +    case TypeRef::PRIMITIVE: +        // TODO: Should check if the type matches +        if( m_core_type != other.m_core_type ) +            throw ::std::runtime_error("Type mismatch (core)"); +        break; +    case TypeRef::TUPLE: +        if( m_inner_types.size() != other.m_inner_types.size() ) +            throw ::std::runtime_error("Type mismatch (tuple size)"); +        for(unsigned int i = 0; i < m_inner_types.size(); i ++ ) +            m_inner_types[i].match_args( other.m_inner_types[i], fcn ); +        break; +    case TypeRef::REFERENCE: +    case TypeRef::POINTER: +        if( m_is_inner_mutable != other.m_is_inner_mutable ) +            throw ::std::runtime_error("Type mismatch (inner mutable)"); +        m_inner_types[0].match_args( other.m_inner_types[0], fcn ); +        break; +    case TypeRef::ARRAY: +        throw ::std::runtime_error("TODO: TypeRef::match_args on ARRAY"); +    case TypeRef::GENERIC: +        throw ::std::runtime_error("Encountered GENERIC in match_args"); +    case TypeRef::PATH: +        throw ::std::runtime_error("TODO: TypeRef::match_args on PATH"); +    case TypeRef::ASSOCIATED: +        throw ::std::runtime_error("TODO: TypeRef::match_args on ASSOCIATED"); +    } +} +  bool TypeRef::is_concrete() const  {      switch(m_class) diff --git a/src/types.hpp b/src/types.hpp index 2c3a2056..27d74662 100644 --- a/src/types.hpp +++ b/src/types.hpp @@ -123,6 +123,10 @@ public:      /// Merge with another type (combines known aspects, conflitcs cause an exception)
      void merge_with(const TypeRef& other);
 +    /// Replace 'GENERIC' entries with the return value of the closure
 +    void resolve_args(::std::function<TypeRef(const char*)> fcn);
 +    /// Match 'GENERIC' entries with another type, passing matches to a closure
 +    void match_args(const TypeRef& other, ::std::function<void(const char*,const TypeRef&)> fcn) const;
      /// Returns true if the type is fully known (all sub-types are not wildcards)
      bool is_concrete() const;
 @@ -131,6 +135,7 @@ public:      bool is_unit() const { return m_class == UNIT; }
      bool is_path() const { return m_class == PATH; }
      bool is_type_param() const { return m_class == GENERIC; }
 +    bool is_reference() const { return m_class == REFERENCE; }
      const ::std::string& type_param() const { assert(is_type_param()); return m_path[0].name(); }
      AST::Path& path() { assert(is_path() || m_class == ASSOCIATED); return m_path; }
      ::std::vector<TypeRef>& sub_types() { return m_inner_types; }
 | 
