본문 바로가기
백준브실골/기타

백준 7453, 합이 0인 네 정수

by oculis 2023. 4. 5.
728x90

개요

문제 링크
골드 2, Meet in the middle, Bisect
4000×4 행렬에서 각 열 마다 하나씩 원소를 뽑아 합이 0이 되게 만드는 경우의 수


접근

  1. SSP의 Meet in the middle을 쓰는 두 번째 문제, 2295와 매우 유사한 문제. 접근이 아주 비슷하다. SSP를 알고 있다는 가정 하에 작성했으니 SSP와 관련된 내용을 숙지하길 바란다.
  2. 우선 브루트포스부터 생각해보자. 아래와 같이 4중 포문을 쓰면 되고, 여기까지는 쉽게 생각할 수 있다.
     for (a=0~n)
         for (b=0~n)
             for (c=0~n)
                 for (d=0~n)
                     if (A[a]+B[b]+C[c]+D[d]==0)
                         count++
    그런데 시간을 생각해보면, 어 12초나 줬네? 되나? 안된다. 4000^4만큼의 연산을 하는데, 4000^4=2.56e14이므로 간단한 1e9번 연산을 2초 정도에 한다고 해도 14시간이 넘어간다.
  3. 왜 이렇게 까지 시간이 급격히 커지냐 하면 지수의 힘이다. 그래서 우리는 4000이 곱해지는 횟수를 최소화해야 하고, O(N^2)을 두 번 쓰는 MITM (meet in the middle)을 사용하게 된다. A+B+C+D=0인 상황이므로, A+B=-C-D인 상황을 찾으면 되고, 각각은 O(N^2)에 해줄 수 있다.
  4. 기존의 SSP와 유사한 방법이지만, 제약 조건이 두 가지가 있다.
    1. SSP에서는 원소 개수에 제한이 없어 재귀문을 사용해 2^n가지의 선택을 했다. 반면 여기서는 원소가 두 개로 제한되므로 이중포문으로 n^2가지 선택을 한다.
    2. SSP에서는 중복을 허용하지 않는다. 하지만 이 문제는 당연히 중복을 허용하게 된다. 왜냐? A[i]와 B[i]를 같이 선택해도 중복된 선택이 아니기 때문이다.
  5. 2295와의 차이점은 2295는 존재여부만 파악하고, 이 문제는 개수를 파악한다는 것이다. 따라서 map을 이용해 A+B와 C+D의 개수를 세어주고, -A-B=C+D임을 찾으면 되므로, C+D를 저장한 배열에서 -A-B의 개수를 찾아 더해주면 된다.
  6. 방법은 중복원소를 포함하는 Vector 두 개에서 upper bound - lower bound가 최선인 것 같다. -A-B가 존재하는 범위를 구하는 것이다. map을 쓰면 log(N)의 탐색 때문에 시간초과가 나고, 중복 제거를 위해 count vector를 이용해도 시간 개선이 뚜렷하지 않았다. 모든 방법은 아래 코드에 적어뒀다.

Pseudo code

for (i=0~n)
    for (j=0~n)
        AB_vector.push(A[i]+B[j])
        CD_vector.push(C[i]+D[j])
sort(AB_vector), sort(CD_vector)
for (x : AB_vector)
    // -A[i]-B[j]가 존재하는 범위
    answer += upper(CD,-x)-lower(CD,-x)

Source code

// 중복허용 vector -> 4000ms
#include <bits/stdc++.h>
using namespace std;

#define N 4010
#define pb push_back
#define low(x) lower_bound(all(m2),x)
#define upp(x) upper_bound(all(m2),x)
#define all(a) a.begin(),a.end()
using ld=long long;
vector<int> m1,m2;
int n,a[N],b[N],c[N],d[N];
ld ans;

int main() {
    cin.tie(0)->ios::sync_with_stdio(0);
    cin>>n;
    for (int i=0;i<n;i++)
        cin>>a[i]>>b[i]>>c[i]>>d[i];
    for (int i=0;i<n;i++)
        for (int j=0;j<n;j++) {
            m1.pb(a[i]+b[j]);
            m2.pb(c[i]+d[j]);
        }
    sort(all(m1)),sort(all(m2));
    for (auto x:m1)
        ans+=upp(-x)-low(-x);
    cout<<ans;
}
// count vector 이용 -> 4232ms
#include <bits/stdc++.h>
using namespace std;

#define N 4010
#define pb push_back
#define all(a) a.begin(),a.end()
#define low(a,x) lower_bound(all(a),x)-a.begin()
using ld=long long;
vector<ld> t1,t2,m1,m2,c1,c2;
int n,a[N],b[N],c[N],d[N];
ld ans;

int main() {
    cin.tie(0)->ios::sync_with_stdio(0);
    cin>>n;
    for (int i=0;i<n;i++)
        cin>>a[i]>>b[i]>>c[i]>>d[i];
    for (int i=0;i<n;i++)
        for (int j=0;j<n;j++) {
            t1.pb(a[i]+b[j]);
            t2.pb(c[i]+d[j]);
        }
    int m=n*n;
    sort(all(t1)),sort(all(t2));
    m1.pb(t1[0]),m2.pb(t2[0]);
    c1.pb(1),c2.pb(1);

    for (int i=1;i<m;i++) {
        if (m1.back()==t1[i]) c1.back()++;
        else m1.pb(t1[i]),c1.pb(1);
        if (m2.back()==t2[i]) c2.back()++;
        else m2.pb(t2[i]),c2.pb(1);
    }
    for (int i=0;i<m1.size();i++) {
        int j=low(m2,-m1[i]);
        if (m1[i]+m2[j]==0)
            ans+=c1[i]*c2[j];
    }
    cout<<ans;
}
// map 이용 -> TLE
#include <bits/stdc++.h>
using namespace std;

#define N 4010
using ld=long long;
map <int,int> m;
int n,a[N],b[N],c[N],d[N];
ld ans;

int main() {
    cin>>n;
    for (int i=0;i<n;i++)
        cin>>a[i]>>b[i]>>c[i]>>d[i];
    for (int i=0;i<n;i++)
        for (int j=0;j<n;j++)
            m[a[i]+b[j]]++;
    for (int i=0;i<n;i++)
        for (int j=0;j<n;j++)
            ans+=m[-c[i]-d[j]];
    cout<<ans;
}
728x90

댓글