C++从零实现一个Variant.

MyVariant

从零实现一个C++Variant.

本文内容参考博主双笙子佯谬

问题场景

我们知道,variant是支持存储多种不同指定类型的的容器。

比如对于一个variant<int,double,std::sttring>,我们可以赋值int,double,string类型。而这些类型个数是不确定的,所以我们一定需要变参模板。

所以一个基本的框架形成了

template<typename...Ts>
struct Variant{
public:
    template<typename T>
    Variant(T value){}
private:
    ??? value;
};

这里我们的value类型是一个很尴尬的选择,可能是int,可能是std::string,而我们也不可能使用类似T value的操作。因此,我们可以选择union.

union{
    ???;
    ???;
}value;

问题又来了,里面的类型怎么填?你可能会想如下类似:

union{
    Ts...;
}value;

但是很遗憾,union不支持变参模板,直接使用union不成功。因此需要我们手动实现一个类似的union.

union简而言之就是内存重叠。多个类型共用一个内存,内存大小为其中最大的那个类型大小值。

static constexpr size_t max_size()
    {
        size_t max = 0;
        // 使用折叠表达式(C++17)
        ((max = (sizeof(Ts) > max ? sizeof(Ts) : max)), ...);
        return max;
    }
alignas(max_size()) char m_union[max_size()];

这里我们每先计算出所有类型之中的最大值,然后开辟一块内存,同时内存对齐。

之所以需要创建一个max_size函数,是因为std::max在不同的编译器中实现不同,可能不支持编译器操作。

此外为了能够识别不同的类型,我们还需要添加一个标志

std::size_t m_index;

现在的结构就是这样

template<typename...Ts>
struct Variant{
public:
    template<typename T>
    Variant(T value){}
private:
    static constexpr size_t max_size()
    {
        size_t max = 0;
        // 使用折叠表达式(C++17)
        ((max = (sizeof(Ts) > max ? sizeof(Ts) : max)), ...);
        return max;
    }
	alignas(max_size()) char m_union[max_size()];
};

类型的识别

我们使用m_index来作为类型标签,那么一定要有方法能够进行index和类型之间的转换或者说“识别”。

比如Variant<int,double,std::string>,

  • index = 0,T为int
  • index=1,T为double
  • index = 2,T为std::string;

这就需要我们用花编译器的算法

template <typename, typename>
struct VariantIndex { };
template <typename, size_t>
struct VariantAlternative { };


template <typename T, typename... Ts>
struct VariantIndex<Variant<T, Ts...>, T> {
    static constexpr size_t value = 0;
};

template <typename T0, typename T, typename... Ts>
struct VariantIndex<Variant<T0, Ts...>, T> {
    static constexpr size_t value = VariantIndex<Variant<Ts...>, T>::value + 1;
};

template <typename T, typename... Ts>
struct VariantAlternative<Variant<T, Ts...>, 0> {
    using type = T;
};

template <typename T, typename... Ts, size_t I>
struct VariantAlternative<Variant<T, Ts...>, I> {
    using type = typename VariantAlternative<Variant<Ts...>, I - 1>::type;
};

先看VariantIndex,注意到这里有三个VariantIndex,第一个内容为空,第二、三个是特化版本。之所以这样是我们想要让对应的类型在特化版本去正确的匹配对应的类型和index,如果出现类型异常,那么就会走到第一个空版本,而空版本中没有type,调用者使用的时候就会报错。这样在编译期就能检查出代码的错误。

template <typename T, typename... Ts>
struct VariantIndex<Variant<T, Ts...>, T> {
    static constexpr size_t value = 0;
};

template <typename T0, typename T, typename... Ts>
struct VariantIndex<Variant<T0, Ts...>, T> {
    static constexpr size_t value = VariantIndex<Variant<Ts...>, T>::value + 1;
};

现在来看他们,

template <typename T, typename... Ts>
struct VariantIndex<Variant<T, Ts...>, T> {
    static constexpr size_t value = 0;
};

意思明显,当类型T和Variant的第一个类型匹配了,那么value就是0,也就是Variant的第一个类型对应的index就是0.假如第一个没有匹配呢?就会走到第二个特化:

template <typename T0, typename T, typename... Ts>
struct VariantIndex<Variant<T0, Ts...>, T> {
    static constexpr size_t value = VariantIndex<Variant<Ts...>, T>::value + 1;
};

这里的Variant的第一个类型是T0,而我们要判断的类型是T,T0和T不匹配,那我们就去检查剩下的类型即Ts...。

比如Variant<int,double,std::String>,我们使用double进行判断,那么对应的实例化之后就是

template <typename T0(int), typename T(double), typename... Ts(double,std::string)>
struct VariantIndex<Variant<int, Ts...>, double> {
    static constexpr size_t value = VariantIndex<Variant<Ts...>, T>::value + 1;
};

这时候就需要再次递归模板调用

static constexpr size_t value = VariantIndex<Variant<Ts...>, T>::value + 1;

同时还要对value进行+1,因为如果第0个不匹配就判断第一个,第一个也不匹配就判断第二个,因此value+1.

这个明白了那么VariantAlternative也是同理,不过这个是通过index反推类型,第一个特化只要I=0返回第一个类型,否则的话每次I-1去匹配剩下的类型,知道I成为0.

比如Variant<int,double,std::string>,我们使用2进行判断,那么一开始

VariantAlternative<Variant<int,double,std::string>,2>

那么

type = typename VariantAlternative<Variant<double,std::String>, I - 1>::type
    = typename VariantAlternative<Variant<double,std::String>, 1>::type
    = typename VariantAlternative<std::string,0>::type;

最后递归推到到第一个特化,type = T ,也就是std::string.

OK啊,了解这两个重要的函数,那么实现我们的构造函数和get函数就很简单了。

构造函数

#if __cplusplus >= 202002L
    template <typename T>
        requires(std::is_same_v<T, Ts> || ...)
    Variant(T value)
        : m_index(VariantIndex<Variant, T>::value)
    {
        T* p = reinterpret_cast<T*>(m_union);
        new (p) T(value);
    }
#else
    template <typename T, typename std::enable_if<std::disjunction<std::is_same<T, Ts>...>::value, int>::value = 0>
    Variant(T value)
        : m_index(VariantIndex<Variant, T>::value)
    {
        T* p = reinterpret_cast<T*>(m_union);
        new (p) T(value);
    }
#endif

这里提供了两版本的实现,分别是C++20以及以下版本的实现。

C++20版本使用requires约束T必须至少和Ts中的一个类型相同,才能构造,比较简单无需多言。低版本呢则使用std::enable_if判断。

std::disjunction用于判断一组逻辑进行逻辑或的结果,比如std::disjunctionstd::false_type,std::false_type,std::true_tyoe返回true因为有一个true.

这里判断T和Ts的每个类型是否有一个能够相同,是则能够构造。

然后根据类型,赋值签名,然后用到了placement new操作,在已经开辟好的内存进行构造,无需管理内存,只进行构造和析构。

实现get

template <std::size_t I>
    typename VariantAlternative<Variant, I>::type const& get() const
    {
        static_assert(I < sizeof...(Ts), "out of range");

        if (m_index != I) {
            throw BadVariantAccess();
        }

        using _type = typename VariantAlternative<Variant, I>::type;
        return *reinterpret_cast<_type const*>(m_union);
    }


template <typename T>
T const& get() const
{
    return get<VariantIndex<Variant, T>::value>();
}

这里有两个重载函数,第一个函数用来底层的类型判断获取内容,第二个是方便用户调用的get类型的函数。

第二个重载给第一个重载传入的模板参数是通过VariantIndex推导出来的类型签名,将这个index传入第一个重载模板参数。

首先静态判断I是否超出范围,没有的话,在进行m_index和I的比较,也就是签名的比较。签名相同,说明此事存储的类型等于我要get的类型就没有问题,否则就会报自定义的类型错误,BadVariantAccess

struct BadVariantAccess : std::exception {
    BadVariantAccess() = default;
    virtual ~BadVariantAccess() = default;

    const char* what() const noexcept override
    {
        return "BadVariantAccess";
    }
};

如果上面两步都没有错误,那么下面就是将union的存储的内容以get要求的类型返回,Variant将签名对应的类型推导出来,然后将m_union强转成对应的类型再解引用传出去。

using _type = typename VariantAlternative<Variant, I>::type;
return *reinterpret_cast<_type const*>(m_union);

至此,get完成。

析构函数

这个问题比较复杂,构造函数的时候我们提到,因为placement new在已经开辟的内存进行构造/析构,不同的类型占用的空间也不同,析构行为是不一样的。

比如一个int析构和std::string析构,int只占用4字节,而std::string占用32字节效果肯定不一。

因此我们需要为每个不同的类型生成不同的析构函数,这里我们使用了表存储了对应的函数指针。

static void (**get_variant_destructors() noexcept)(char*) noexcept
    {
        static void (*destructors[max_size()])(char*) noexcept = {
            [](char* m_union) noexcept {
                reinterpret_cast<Ts*>(m_union)->~Ts();
            }...
        };
        return destructors;
    }


~Variant() noexcept
    {
        get_variant_destructors()[m_index](m_union);
    }

这个get_variant_destructors很复杂,先看内部,有一个静态函数指针数组,接受一个指针,然后将其强转成不同的类型调用析构函数。然后返回这个数组。

这个数组的形式如下

void (*[])(char*)
/*实际上:
   [ void(*)(char*),void(*)(char*)... ]
   			|
   [ void(*[])(char*)]
*/

实际上是一个‘二维数组’,因此外部函数返回值是一个双重指针。

于是形式如下:

void (**get_variant_destructors())(char*)

这是一个返回指向二维数组指针的函数,这个二维数组或者说这个这个一维指针的每一个元素都是一个函数指针,接受一个char*指针参数,返回void类型。

然后析构函数通过这个函数拿到了存储各种不同类型的表,再根据m_index这个类型签名找到对应的(析构函数)函数指针调用。

其他函数

析构函数使用了静态的函数指针表来存储对应的操作行为,性能高,O(1)复杂度。当你能掌握这里的方法,其他的拷贝构造、赋值,移动构造、赋值。甚至于其他的vist函数等都与此原理趋同。不再赘述。

其他需要注意的问题就是赋值的话,那么说明之前已经构造好了,也就是m_union不为空,那么这时候就是对m_union进行赋值即可。构造的话,则需要开辟m_union内存。

源代码

#include <iostream>
#include <string>
#include <type_traits>

struct BadVariantAccess : std::exception {
    BadVariantAccess() = default;
    virtual ~BadVariantAccess() = default;

    const char* what() const noexcept override
    {
        return "BadVariantAccess";
    }
};

template <std::size_t I>
struct InPlcaeIndex {
    InPlcaeIndex() = default;
};

template <typename, typename>
struct VariantIndex { };
template <typename, size_t>
struct VariantAlternative { };

template <typename... Ts>
struct Variant {
private:
    static constexpr size_t max_size()
    {
        size_t max = 0;
        // 使用折叠表达式(C++17)
        ((max = (sizeof(Ts) > max ? sizeof(Ts) : max)), ...);
        return max;
    }

    alignas(max_size()) char m_union[max_size()];
    std::size_t m_index;

    static void (**get_variant_destructors() noexcept)(char*) noexcept
    {
        static void (*destructors[max_size()])(char*) noexcept = {
            [](char* m_union) noexcept {
                reinterpret_cast<Ts*>(m_union)->~Ts();
            }...
        };
        return destructors;
    }

    static void (**move_constructors() noexcept)(char*, char*) noexcept
    {
        static void (*move_constructors[max_size()])(char*, char*) noexcept = {
            [](char* union_dst, char* union_src) noexcept {
                new (union_dst) Ts(std::move(*reinterpret_cast<Ts*>(union_src)));
            }...
        };
        return move_constructors;
    }
    static void (**move_assignment_constructors() noexcept)(char*, char*) noexcept
    {
        static void (*move_assignment_constructors[max_size()])(char*) noexcept = {
            [](char* union_dst, char* union_src) noexcept {
                *reinterpret_cast<Ts*>(union_dst) = std::move(*reinterpret_cast<Ts*>(union_src));
            }...
        };
        return move_assignment_constructors;
    }

    static void (**copy_constructors() noexcept)(char*, const char*) noexcept
    {
        static void (*copy_constructors[max_size()])(char*, const char*) noexcept = {
            [](char* union_dst, const char* union_src) noexcept {
                new (union_dst) Ts(*reinterpret_cast<Ts const*>(union_src));
            }...
        };
        return copy_constructors;
    }
    static void (**copy_assignment_constructors() noexcept)(char*, const char*) noexcept
    {
        static void (*copy_assignment_constructors[max_size()])(char*, const char*) noexcept = {
            [](char* union_dst, const char* union_src) noexcept {
                *reinterpret_cast<Ts*>(union_dst) = *reinterpret_cast<Ts const*>(union_src);
            }...
        };
        return copy_assignment_constructors;
    }

    template <typename Lambda>
    static typename std::common_type<typename std::invoke_result<Lambda, Ts&>::type...>::type (**visitable_c() noexcept)(const char*, Lambda) noexcept
    {
        static void (*visitable[max_size()])(const char*, Lambda) noexcept = {
            [](const char* m_union, Lambda lambda) noexcept -> typename std::common_type<typename std::invoke_result<Lambda, Ts&>::type...>::type {
                return lambda(*reinterpret_cast<const Ts*>(m_union));
            }...
        };
        return visitable;
    }

    template <typename Lambda>
    static typename std::common_type<typename std::invoke_result<Lambda, Ts&>::type...>::type (**visitable() noexcept)(char*, Lambda) noexcept
    {
        using visit_return_type = typename std::common_type<typename std::invoke_result<Lambda, Ts&>::type...>::type;

        static visit_return_type (*visitable[max_size()])(char*, Lambda) noexcept = {
            [](char* m_union, Lambda lambda) noexcept -> visit_return_type {
                return lambda(*reinterpret_cast<Ts*>(m_union));
            }...
        };
        return visitable;
    }

public:
    ~Variant() noexcept
    {
        get_variant_destructors()[m_index](m_union);
    }

    Variant(Variant&& other)
        : m_index(other.m_index)
    {
        move_constructors()[m_index](m_union, other.m_union);
    }
    Variant& operator=(Variant&& other)

    {
        m_index = other.m_index;
        move_assignment_constructors()[m_index](m_union, other.m_union);
    }

    Variant(Variant const& other)
        : m_index(other.m_index)
    {
        copy_constructors()[m_index](m_union, other.m_union);
    }

    Variant& operator=(Variant const& other)
    {
        m_index = other.m_index;
        copy_assignment_constructors()[index()](m_union, other.m_union);
    }

    template <std::size_t I, typename... Args>
    Variant(InPlcaeIndex<I>, Args&&... value_types)
        : m_index(I)
    {
        new (m_union) typename VariantAlternative<Variant, I>::type(std::forward<Args>(value_types)...);
    }

#if __cplusplus >= 202002L
    template <typename T>
        requires(std::is_same_v<T, Ts> || ...)
    Variant(T value)
        : m_index(VariantIndex<Variant, T>::value)
    {
        T* p = reinterpret_cast<T*>(m_union);
        new (p) T(value);
    }
#else
    template <typename T, typename std::enable_if<std::disjunction<std::is_same<T, Ts>...>::value, int>::value = 0>
    Variant(T value)
        : m_index(VariantIndex<Variant, T>::value)
    {
        T* p = reinterpret_cast<T*>(m_union);
        new (p) T(value);
    }
#endif

    constexpr size_t index() const
    {
        return m_index;
    }

    template <typename T>
    constexpr bool holds_alternative() const
    {
        return VariantIndex<Variant, T>::value == m_index;
    }
    template <std::size_t I>
    typename VariantAlternative<Variant, I>::type const& get() const
    {
        static_assert(I < sizeof...(Ts), "out of range");

        if (m_index != I) {
            throw BadVariantAccess();
        }

        using _type = typename VariantAlternative<Variant, I>::type;
        return *reinterpret_cast<_type const*>(m_union);
    }

    template <typename T>
    T const& get() const
    {
        return get<VariantIndex<Variant, T>::value>();
    }

    //
    template <class Lambda>
    typename std::common_type<typename std::invoke_result<Lambda, Ts&>::type...>::type visit(Lambda lambda) const
    {
        return visitable_c<Lambda>()[m_index]((m_union), lambda);
    }

    template <class Lambda>
    typename std::common_type<typename std::invoke_result<Lambda, Ts&>::type...>::type visit(Lambda lambda)
    {
        using visit_return_type = typename std::common_type<typename std::invoke_result<Lambda, Ts&>::type...>::type;
        return visitable<Lambda>()[m_index]((m_union), lambda);
    }

    template <std::size_t I>
    auto get_if() const -> typename VariantAlternative<Variant, I>::type const*
    {
        static_assert(I < sizeof...(Ts), "I out of range!");
        if (m_index != I) {
            return nullptr;
        }
        return reinterpret_cast<typename VariantAlternative<Variant, I>::type const*>(m_union);
    }

    template <std::size_t I>
    auto get_if() -> typename VariantAlternative<Variant, I>::type*
    {
        static_assert(I < sizeof...(Ts), "I out of range!");
        if (m_index != I) {
            return nullptr;
        }
        return reinterpret_cast<typename VariantAlternative<Variant, I>::type*>(m_union);
    }
};

template <typename T, typename... Ts>
struct VariantIndex<Variant<T, Ts...>, T> {
    static constexpr size_t value = 0;
};

template <typename T0, typename T, typename... Ts>
struct VariantIndex<Variant<T0, Ts...>, T> {
    static constexpr size_t value = VariantIndex<Variant<Ts...>, T>::value + 1;
};

template <typename T, typename... Ts>
struct VariantAlternative<Variant<T, Ts...>, 0> {
    using type = T;
};

template <typename T, typename... Ts, size_t I>
struct VariantAlternative<Variant<T, Ts...>, I> {
    using type = typename VariantAlternative<Variant<Ts...>, I - 1>::type;
};

int main()
{
    Variant<int, double, std::string> var = std::string("SAdsad");
    auto p = var.visit([](auto a) {
        std::cout << a << std::endl;
        return "asd";
    });
    std::cout << p << std::endl;
}
posted @ 2026-03-17 11:22  大胖熊哈  阅读(7)  评论(0)    收藏  举报