1. ホーム
  2. java

4,000,000,000個の数字のうち、最も頻度の高い100個の数字を求めるにはどうしたらいいですか?

2023-07-24 10:24:43

質問

昨日のコーディングの面接で、例えば4,000,000,000個の整数(重複があるかもしれません)から最も頻度の高い100個の数字を得る方法を尋ねられました。

813972066
908187460
365175040
120428932
908187460
504108776

最初に思いついたのはHashMapを使った方法でした。

static void printMostFrequent100Numbers() throws FileNotFoundException {
    
    // Group unique numbers, key=number, value=frequency
    Map<String, Integer> unsorted = new HashMap<>();
    try (Scanner scanner = new Scanner(new File("numbers.txt"))) {
        while (scanner.hasNextLine()) {
            String number = scanner.nextLine();
            unsorted.put(number, unsorted.getOrDefault(number, 0) + 1);
        }
    }

    // Sort by frequency in descending order
    List<Map.Entry<String, Integer>> sorted = new LinkedList<>(unsorted.entrySet());
    sorted.sort((o1, o2) -> o2.getValue().compareTo(o1.getValue()));

    // Print first 100 numbers
    int count = 0;
    for (Map.Entry<String, Integer> entry : sorted) {
        System.out.println(entry.getKey());
        if (++count == 100) {
            return;
        }
    }
}

しかし、4,000,000,000 個の数字のデータセットでは、おそらく OutOfMemory 例外を投げるでしょう。しかも、4,000,000,000はJava配列の最大長を超えるので、たとえばテキストファイルで、しかもソートされていない数字があるとします。ビッグデータには、マルチスレッドやMap Reduceが適しているのではないでしょうか?

データが利用可能なメモリに収まらない場合、上位100個の値はどのように計算されるのでしょうか?

どのように解決するのですか?

もし、データが ソートされている であれば、上位100位までを O(n) ここで n はデータの大きさである。データはソートされているので、異なる値は連続している。データを1回走査する間にそれらを数えると グローバル 頻度、これはデータがソートされていないときには利用できません。

これをどのように行うかについては、以下のサンプルコードを参照してください。また、このアプローチ全体の(Kotlinによる)実装が GitHub

注意 ソートは必須ではありません。何 必要なのは、異なる値が連続していることで、順序を定義する必要はありません。

データファイルを(外部の)マージソートを使っておおまかにソートすることができます。 O(n log n) で、入力データファイルをメモリに収まる程度の小さなファイルに分割し、ソートされたファイルに書き出してからマージすることで可能です。



このコードサンプルについて。

  • ソートされたデータは long[] . ロジックは値を1つずつ読み取るので、ソートされたファイルからデータを読み取るのと同じ近似値でOKです。

  • OP は、同じ頻度の複数の値がどのように扱われるべきかを指定していません。その結果、コードは、結果が順不同の上位 N 個の値であり、同じ頻度の他の値がないことを暗示しないことを保証する以上のことは何もしません。

import java.util.*;
import java.util.Map.Entry;

class TopN {
    private final int maxSize;
    private Map<Long, Long> countMap;

    public TopN(int maxSize) {
        this.maxSize = maxSize;
        this.countMap = new HashMap(maxSize);
    }

    private void addOrReplace(long value, long count) {
        if (countMap.size() < maxSize) {
            countMap.put(value, count);
        } else {
            Optional<Entry<Long, Long>> opt = countMap.entrySet().stream().min(Entry.comparingByValue());
            Entry<Long, Long> minEntry = opt.get();
            if (minEntry.getValue() < count) {
                countMap.remove(minEntry.getKey());
                countMap.put(value, count);
            }
        }
    }

    public Set<Long> get() {
        return countMap.keySet();
    }

    public void process(long[] data) {
        long value = data[0];
        long count = 0;

        for (long current : data) {
            if (current == value) {
                ++count;
            } else {
                addOrReplace(value, count);
                value = current;
                count = 1;
            }
        }
        addOrReplace(value, count);
    }

    public static void main(String[] args) {
        long[] data = {0, 2, 3, 3, 4, 5, 5, 5, 5, 6, 6, 6, 7};
        TopN topMap = new TopN(2);

        topMap.process(data);
        System.out.println(topMap.get()); // [5, 6]
    }
}