개요
문제 링크
플래 3, Bruteforce, Meet in the middle, Bitmasking
위 아래 선택을 절반씩 한 것들의 점수합의 차이가 최소가 되도록 하기
접근
SSP의 Meet in the middle을 쓰는 네 번째 문제, 중간에서 만나기로 머리를 부숴놓는 문제다. 꽤 많이 까다롭고 고민할게 많으니 우선 중간에서 만나기가 무엇인지 SSP 글에서 꼭 숙지하고 읽어보기를 바란다.
우선 브루트포스 방식을 생각해보자. 비트마스킹을 어디에 쓰냐? 하면 두개의 행벡터에서 위를 선택하는 경우 0, 아래를 선택하는 경우를 1이라고 하면 하나의 선택을 36자리 비트로 나타낼 수 있다. 그럼 절반만 아래를 선택해야 하므로 1이 18개인 모든 경우에 대해 따져보면 된다.
엥? next permutation 쓰면 끝나는거 아니냐?
do { ld v=get(c); update(ans,v) } while(next_permutation(all(c)));
대강 이런 식 아닌가 싶은데 이 생각보다 크다. 9e9쯤 된다. 그래서 2초에는 절대 안된다.
그래서 절반을 나눈다. 왼쪽이 1개, 오른쪽이 17개인 (1,17), 2개 16개인 (2,16) ... (17,1) 까지 모두 해줘야 한다. 이걸 빨리, 정확하게 하는게 생각보다 까다로워서 머리가 아프다.
- 구현을 하기 전에 업데이트 조건부터 생각하자. 차이를 구하고 차이가 minimum보다 작으면 무조건 배열도 같이 업데이트 한다.
- 구한 차이가 minimum과 같으면 사전순으로 업데이트 하는데, 배열의 앞에서부터 다른값이 나오면 대소를 비교해 return 한다.
- 그런데 이걸 boolean 벡터를 써주면 참 좋겠지만, 그럼 TLE가 난다. 그래서 비트마스킹을 써줘야 한다. 이게 머리를 깨는 포인트다.
우리가 해줄 것은
- 18개까지 모든 왼쪽/오른쪽 선택의 차이와 비트수를 저장한다.
- N/2까지의 왼쪽 비트수에 대해 N/2-왼쪽 비트수만큼의 비트를 가진 오른쪽 선택에 대해 루프를 돌면서 업데이트 한다.
- 업데이트 할때는 차이가 저장된 최솟값과 다르면 무조건 업데이트 하고, 같으면 최대 비트부터 비교하는데, 이때는 비트연산자 &를 사용한다.
자 1번부터 해주자.
void dp(int i,int e,ld sum,ld bit,int bitcount) { if (i==e) { if (e==n) rm[bitcount].pb({sum,bit}); else lm[bitcount].pb({sum,bit}); return; } dp(i+1,e,sum+a[i],bit,bitcount); dp(i+1,e,sum-b[i],bit+pp[e-1-i],bitcount+1); }
- 일반적인 meet in the middle에서 bitcount와 bit가 추가되었다. 왼쪽과 오른쪽에 대해 사용한 비트수에 따라 18개의 벡터를 만들고, 첫번째 행의 경우 더하고, 두번째 행은 빼준다. 두번째 행을 뺀 경우 비트가 하나 추가되고, bitcount도 하나 추가된다. 추가하는 비트는 미리 배열로 만들어준다.
- 이때 마지막에 bit출력도 해주어야 하므로 vector는 sum과 bit를 pair로 저장한다.
다음 2번이다.
// i=왼쪽 비트수, j=오른쪽 비트수 for (int i=0;i<=n/2;i++) { int j=n/2-i; for (auto p:lm[i]) { // 왼쪽 sum과 차이가 가장 적도록 bisect pll r={-p.first,p.second}; int ll=low(all(rm[j]),r,cmp)-rm[j].begin(); int rr=rm[j].size()-1; int s=max(ll-30,0); int e=min(ll+10,rr); for (int k=s;k<=e;k++) { auto q=rm[j][k]; update(p,q); } } }
- 왼/오 비트 합이 N/2가 되도록 루프를 돌것인데, 2^18=262144이므로 전부 다 돌 수는 없고 (O(N^2)), bisect를 통해 왼쪽합 + 오른합이 최소가 될 후보군에만 돌아준다. 그냥 lower bound에 30칸 왼쪽과 10칸 오른쪽을 루프 돌려주니 되었다.
- lower bound를 보면 cmp 함수가 있는데, 차이가 같으면 bit를 비교하고, 차이가 다르면 차이를 비교한다. bit를 비교하는 방법은 3번에서 다뤄보자.
다음 3번이다.
void update(pll &p,pll &q) { auto &[ls,lbit]=p; auto &[rs,rbit]=q; ld d=abs(ls+rs); // 작으면 무조건 업뎃 if (d<ans) { ans=d; lb=lbit,rb=rbit; // 같으면? 비트 비교 ord 함수 씀 } else if (d==ans) { if (ord(lbit,lb)) lb=lbit,rb=rbit; else if (lb==lbit&&ord(rbit,rb)) lb=lbit,rb=rbit; } }
- 만약 왼차이+오른차이의 절댓값, 즉 전체 차이값이 기존의 최소보다 작으면 무조건 업데이트 한다.
- 같다면 사전 순 비교를 해야 하는데, 아래와 같이 한다.
&연산자를 이용해 큰 자리부터 비교한다. 왼쪽비트에서 차이가 있다면 업데이트 하고, 차이가 없다면 오른쪽 비트를 비교해 업데이트 한다.bool ord(const ld &a,const ld &b) { for (int i=n/2-1;i>=0;i--) if ((a&pp[i])!=(b&pp[i])) return (a&pp[i])<(b&pp[i]); return 0; }
비트마스킹만 아니었어도 실수가 절반은 줄었을 문제. 내가 어렵게 설명했지만 핵심은 번 시행을 할 수 없다는 것이고, 그러면 절반을 나눠 combination이 아닌 모든 개수에 대해 경우를 구하고 비트수에 맞는 탐색을 해주는 것이다.
Source code
#include <bits/stdc++.h>
using namespace std;
#define N 40
#define pb push_back
#define vec vector
#define low lower_bound
#define all(a) a.begin(),a.end()
using ld=long long;
using pll=pair<ld,ld>;
ld n,a[N],b[N],pp[N];
vec<pll> lm[N],rm[N];
ld lb,rb,ans=1e18;
bool ord(const ld &a,const ld &b) {
for (int i=n/2-1;i>=0;i--)
if ((a&pp[i])!=(b&pp[i]))
return (a&pp[i])<(b&pp[i]);
return 0;
}
bool cmp(const pll &a,const pll &b) {
if (a.first==b.first)
return ord(a.second,b.second);
return a.first<b.first;
}
void dp(int i,int e,ld k,ld bit,int c) {
if (i==e) {
if (e==n) rm[c].pb({k,bit});
else lm[c].pb({k,bit});
return;
}
dp(i+1,e,k+a[i],bit,c);
dp(i+1,e,k-b[i],bit+pp[e-1-i],c+1);
}
void update(pll &p,pll &q) {
auto &[ls,lbit]=p;
auto &[rs,rbit]=q;
ld d=abs(ls+rs);
if (d<ans) {
ans=d;
lb=lbit,rb=rbit;
} else if (d==ans) {
if (ord(lbit,lb))
lb=lbit,rb=rbit;
else if (lb==lbit&&ord(rbit,rb))
lb=lbit,rb=rbit;
}
}
void print(ld lb, ld rb) {
vec<int> s(n);
for (int i=n/2-1;i>=0;i--)
s[i]=lb%2+1,lb/=2;
for (int i=n-1;i>=n/2;i--)
s[i]=rb%2+1,rb/=2;
for (auto x:s) cout<<x<<" ";
}
int main() {
cin>>n;
pp[0]=pp[n/2]=1;
for (int i=1;i<n/2;i++)
pp[n/2+i]=pp[i]=pp[i-1]*2;
for (int i=0;i<n;i++) cin>>a[i];
for (int i=0;i<n;i++) cin>>b[i];
dp(0,n/2,0,0,0);
dp(n/2,n,0,0,0);
for (auto &x:lm) sort(all(x),cmp);
for (auto &x:rm) sort(all(x),cmp);
for (int i=0;i<=n/2;i++) {
int j=n/2-i;
for (auto p:lm[i]) {
pll r={-p.first,p.second};
int ll=low(all(rm[j]),r,cmp)-rm[j].begin();
int rr=rm[j].size()-1;
int s=max(ll-30,0);
int e=min(ll+10,rr);
for (int k=s;k<=e;k++) {
auto q=rm[j][k];
update(p,q);
}
}
}
print(lb,rb);
}
댓글