// Program listing for file OverlappingNMIDistance.hpp
//
// Return to documentation for file (include/networkit/community/OverlappingNMIDistance.hpp)

#ifndef NETWORKIT_COMMUNITY_OVERLAPPING_NMI_DISTANCE_HPP_
#define NETWORKIT_COMMUNITY_OVERLAPPING_NMI_DISTANCE_HPP_

#include <unordered_map>

#include <networkit/auxiliary/HashUtils.hpp>
#include <networkit/community/DissimilarityMeasure.hpp>

namespace NetworKit {

/**
 * Dissimilarity measure between two community structures of a graph.
 *
 * The class derives from DissimilarityMeasure and overrides getDissimilarity()
 * for both Partition and Cover arguments. Judging by its name and its private
 * helpers it implements an overlapping normalized-mutual-information (NMI)
 * distance; the actual computation lives in the implementation file and is not
 * visible in this header.
 */
class OverlappingNMIDistance final : public DissimilarityMeasure {
public:
    /**
     * Normalization mode used by the measure.
     *
     * NOTE(review): the enumerators presumably name the quantity the mutual
     * information is divided by (min / geometric mean / arithmetic mean / max
     * of the two entropies, or the joint entropy) -- confirm against the
     * definition of normalize().
     */
    enum Normalization {
        MIN,
        GEOMETRIC_MEAN,
        ARITHMETIC_MEAN,
        MAX,
        JOINT_ENTROPY
    };

    /** Creates a measure that uses the default normalization, Normalization::MAX. */
    OverlappingNMIDistance() = default;

    /**
     * Creates a measure that uses the given normalization.
     *
     * @param normalization Normalization mode stored in this instance.
     */
    explicit OverlappingNMIDistance(Normalization normalization)
        : mNormalization{normalization} {}

    /**
     * Replaces the normalization mode stored in this instance.
     *
     * @param normalization New normalization mode.
     */
    void setNormalization(Normalization normalization) {
        mNormalization = normalization;
    }

    /**
     * Dissimilarity of two partitions (override of the DissimilarityMeasure interface).
     *
     * @param G    Graph the partitions refer to.
     * @param zeta First partition.
     * @param eta  Second partition.
     * @return The dissimilarity value as computed by the implementation file.
     */
    double getDissimilarity(const Graph &G, const Partition &zeta,
                            const Partition &eta) override;

    /**
     * Dissimilarity of two covers (override of the DissimilarityMeasure interface).
     *
     * @param G    Graph the covers refer to.
     * @param zeta First cover.
     * @param eta  Second cover.
     * @return The dissimilarity value as computed by the implementation file.
     */
    double getDissimilarity(const Graph &G, const Cover &zeta, const Cover &eta) override;

private:
    /**
     * Bundle returned by calculateClusterAndIntersectionSizes().
     *
     * Presumably sizesX / sizesY hold the cluster sizes of the two covers and
     * intersectionSizes maps a pair of cluster indices to the size of the
     * corresponding intersection -- verify against the implementation.
     *
     * NOTE(review): std::vector and std::pair are used here, but this header
     * only includes <unordered_map> directly; <vector> and <utility> seem to
     * arrive transitively. Consider including them explicitly.
     */
    struct SizesAndIntersections {
        std::vector<count> sizesX;
        std::vector<count> sizesY;
        std::unordered_map<std::pair<index, index>, count, Aux::PairHash> intersectionSizes;
    };

    /** Builds the SizesAndIntersections bundle for the covers X and Y on the given graph. */
    static SizesAndIntersections calculateClusterAndIntersectionSizes(const Graph &graph,
                                                                      const Cover &X,
                                                                      const Cover &Y);

    /** Helper taking two counts w and n; presumably a single entropy term -- TODO confirm. */
    static double h(count w, count n);

    /** Entropy helper for a single size value together with n. */
    static double entropy(count size, count n);

    /** Entropy helper for a whole vector of sizes together with n. */
    static double entropy(const std::vector<count> &sizesX, count n);

    /**
     * Helper operating on one (Xi, Yj) pair: the two sizes, their intersection
     * size and n. Its exact definition is not visible in this header.
     */
    static double adjustedConditionalEntropy(count sizeXi, count sizeYj,
                                             count intersectionSize, count n);

    /**
     * Conditional-entropy helper over the size vectors and the intersection map.
     *
     * @param invertPairIndices Presumably swaps the roles of the two indices in
     *                          each map key -- TODO confirm in the implementation.
     */
    static double conditionalEntropy(
        const std::vector<count> &sizesX,
        const std::vector<count> &sizesY,
        const std::unordered_map<std::pair<index, index>, count, Aux::PairHash> &intersectionSizes,
        bool invertPairIndices,
        count n);

    /**
     * Takes @p value by reference together with a lower bound; @p format and
     * @p printPrecision are presumably used for a diagnostic message when the
     * value has to be adjusted -- verify against the definition.
     */
    static void clampBelow(double &value, double lowerBound, const char *format,
                           int printPrecision = 20);

    /** Counterpart of clampBelow() taking an upper bound. */
    static void clampAbove(double &value, double upperBound, const char *format,
                           int printPrecision = 20);

    /**
     * Normalization helper receiving the selected mode, the mutual information
     * and the two entropies.
     */
    static double normalize(Normalization normalization, double mutualInformation,
                            double entropyX, double entropyY);

    /** Normalization mode currently selected; Normalization::MAX unless overridden. */
    Normalization mNormalization = Normalization::MAX;
};

} // namespace NetworKit

#endif // NETWORKIT_COMMUNITY_OVERLAPPING_NMI_DISTANCE_HPP_