본문 바로가기

알고리즘 문제풀이

고속 푸리에 변환(FFT)의 빠른 구현체

PS에서 고속 푸리에 변환(Fast Fourier Transform, FFT)을 쓰는 문제를 이따금씩 볼 수 있다. 컨볼루션은 주파수 영역에서 원소별 곱 연산(elementwise multiplication)으로 바뀌는 특성을 기반으로, FFT를 통해 \(O(n^2)\) 대신 \(O(n \lg n)\) 시간에 컨볼루션이 가능함을 이용하는 문제들이다.

 

FFT는 크게 Cooley-Tukey, Split-Radix, Tangent의 세 가지 방법이 있다. Cooley-Tukey가 가장 간단하지만 비교적 느리고, Tangent가 가장 빠르지만 구현이 가장 복잡하다. 따라서 PS에서는 Cooley-Tukey를 쓰는 것이 일반적이다. 하지만 Cooley-Tukey는 구현 방법에 따라 실행 시간의 편차가 크게 수십 배까지도 날 수 있다. 따라서 효율적인 구현을 외우는 것이 중요한 문제가 된다.

 

이 글의 목적은 빠른 FFT 구현체를 실제 문제에 적용한 사례와 함께 소개하는 것이다. 풀어볼 문제는 BOJ 17104: 골드바흐 파티션 2로, N=1,000,000의 FFT를 0.5초 이내에 해내야 하므로 시간이 빠듯하다.

 

Cooley-Tukey FFT의 원리 및 기본적인 구현 방법은 blog.myungwoo.kr/54를 참고할 수 있다.

  • 비재귀로 구현하기 위해, 먼저 순서를 적절히 바꾼다. i번째 원소는 i의 비트 표현을 거꾸로 뒤집은 수 j와 바뀐다.
  • 이제 블록 크기를 2, 4, 8, ..., \(2^K\)까지 순서대로 바꾸면서 Cooley-Tukey FFT의 점화식을 처리한다.
  • cos(), sin()을 매번 다시 계산하는 것은 매우 느리므로, \(\omega\)를 계속 곱하는 식으로 처리해야 한다. 이때 실수 오차 누적이 생기기는 하지만 결과 처리에 크게 문제되지는 않는다. 대신 \(10^{10}\)이 넘는 수를 다룰 때는 이 차이가 중요해질 수 있다. 그런 경우에는 매번 cos(), sin()을 호출하는 방법밖에 없다.

아래 구현은 위 블로그를 참고하여 새로 구현한 것이다. 다른 점은 초기 비트플립을 수행하는 부분이다. 프로파일링 결과 비트플립 과정이 전체 수행시간의 20% 가량을 차지하는 것으로 나와서 해당 부분을 Lookup Table (LUT) 기반으로 바꿔준 결과, 시간이 20% 정도 감소한 결과를 얻을 수 있었다.

 

보통의 문제에서는 이 정도 차이가 크지 않지만, BOJ 17104: 골드바흐 파티션 2에서는 Lookup Table 기반 비트플립을 적용해야만 통과할 수 있다. (물론, SSE/AVX 계열 가속을 적용할 수도 있지만, 이 방법은 그보다 간단하면서도 PS에 사용하기에 충분할 만큼 빠르다.)

 

아래 코드는 BOJ 17104: 골드바흐 파티션 2를 FFT로 풀이하는 코드이다. 단, N이 \(2^{24}\)를 넘어가는 경우 비트플립 부분을 수정해 사용해야 한다.

 

#define _CRT_SECURE_NO_WARNINGS
#define _USE_MATH_DEFINES
#include <stdio.h>
#include <vector>
#include <math.h>
#include <memory.h>
#include <complex>

typedef std::complex<double> cd;

void fft_inplace(cd* z, int K, bool fwd) {
    const int M = 1 << K;

    /* prepare bitflip */
    static uint8_t bitflip_lut[256];
    static bool lut_init = false;
    if (!lut_init) {
        for (int i = 0; i < 256; i++) {
            uint8_t j = 0;
            for (int k = 0; k < 8; k++) j |= ((i >> k) & 1) << (7 - k);
            bitflip_lut[i] = j;
        }
        lut_init = true;
    }
    const int J = 24 - K;
    auto bitflip = [&](int b) -> int {
        int c = (bitflip_lut[b & 0xff] << 16) |
            (bitflip_lut[(b >> 8) & 0xff] << 8) |
            (bitflip_lut[(b >> 16) & 0xff]);
        return (c >> J);
    };

    /* permute in advance */
    for (int i = 0; i < M; i++) {
        int j = bitflip(i);
        if (i < j) std::swap(z[i], z[j]);
    }

    /* run FFT */
    const double head = (fwd) ? -2 * M_PI : +2 * M_PI;
    for (int k = 1; k <= K; k++) {
        const int P = (1 << k);
        const int Q = (1 << (k - 1));
        const double theta = head / P;
        const cd t = cd(cos(theta), sin(theta));
        for (int i = 0; i < M; i += P) {
            cd w = 1;
            for (int j = 0; j < Q; j++) {
                /* deal with z[i+j] and z[i+j+Q] */
                cd p = z[i + j];
                cd q = z[i + j + Q];
                cd r = w * q;
                z[i + j] = p + r;
                z[i + j + Q] = p - r;

                /* update w */
                w *= t;
            }
        }
    }

    if (!fwd) {
        for (int i = 0; i < (1 << K); i++) {
            z[i] = z[i] / (double)M;
        }
    }
}

void fft_conv(cd* a, cd* b, int K) {
    fft_inplace(a, K, true);
    fft_inplace(b, K, true);
    for (int i = 0; i < (1 << K); i++) a[i] *= b[i];
    fft_inplace(a, K, false);
}

int T, Q[100001];
int A[1 << 20], Z[1 << 19];
cd E[1 << 20], F[1 << 20];

int main() {
    const int K = 19;

    /* read input */
    scanf("%d", &T);
    for (int i = 1; i <= T; i++) scanf("%d", &Q[i]);

    /* prepare */
    std::vector<int> primes;
    primes.push_back(2);
    for (int x = 3; x <= (1 << (K + 1)); x++) {
        bool composite = false;
        for (auto p : primes) {
            if (p * p > x) break;
            if (x % p == 0) {
                composite = true;
                break;
            }
        }
        if (!composite) primes.push_back(x);
    }

    memset(A, 0, sizeof(A));
    for (auto x : primes) {
        if (x != 2) A[(x - 1) / 2] = 1;
    }

    memset(Z, 0, sizeof(Z));
    for (int i = 0; i < (1 << (K + 1)); i++) E[i] = F[i] = A[i];
    fft_conv(E, F, K + 1);

    for (int i = 0; i < (1 << K); i++) {
        Z[i] = (int)std::round(E[i].real());

        /* deduplication */
        if (i % 2 == 0 && A[i / 2]) Z[i]++;
        Z[i] /= 2;
    }
    Z[1] = 1; // 4 = 2 + 2

    /* print */
    for (int i = 1; i <= T; i++) {
        int N = Q[i];
        printf("%d\n", Z[(N / 2) - 1]);
    }
    return 0;
}