본문 바로가기
백준브실골/DP

백준 2228, 구간 나누기

by oculis 2023. 2. 8.
728x90

개요

문제 링크
골드 3, DP

M개의 인접하지 않은 부분집합의 총 합의 최댓값


접근

  1. N,M이 작으므로 N,M에 대해서는 루프를 여러번 돌려도 상관 없음. 부분합 배열도 미리 만들어서 쓰는게 편하다.
  2. dp를 사용할 때 dp[i][k] = i번째를 포함해 (끝이 i인) k개의 집합으로 이루어진 합의 최댓값을 저장해주면 됨.
  3. 업데이트는? 앞에 있는 것들 중 k-1개의 집합을 사용한 값과 비교한다.
    ex) 4,-7,2,-1,5 와 같은 배열을 가정하자.
i 0 1 2 3 4
a[i] 4 -7 2 -1 5
k=1 1개집합 sum[0-0]
=4
sum[0-1]
=-3
sum[2-2]
=2
sum[2-3]
=1
sum[2-4]
=6

1개의 집합을 사용하는 것까지는 잘 이해할 수 있다.

  1. 이때 i=3까지 2개 집합을 사용하는 경우를 생각하자.
    i=3에서 비교해야 할 것은
// 끝이 0이고 1개 집합 쓴 것 + 부분합
sum[0-0] + sum[2-3], sum[0-0] + sum[3-3]
// 끝이 1이고 1개 집합 쓴 것 + 부분합
sum[0-1] + sum[3-3], sum[1-1] + sum[3-3]
// 끝이 2인 것은 무조건 i=3과 인접하기 때문에 사용할 수 없다.

이렇게 네 가지이다. 그런데 실제로는 3개만 비교할 수 있는데, sum[0-1]과 sum[1-1] 중 무엇이 큰지 dp에 저장하면 되기 때문이다. 위에 표에서 보시다시피 sum[0-1]이 저장되어 있다. 즉, 끝이 i이면서 k개의 집합을 사용했을 때 합의 최댓값을 dp에 저장하면 앞에있는 모든 집합을 고려할 필요가 없다.

  1. 따라서 앞에 있는 것들 중 1개의 집합을 사용한 것들과 뒷부분의 부분합을 더해서 업데이트를 하면 된다.
i 0 1 2 3 4
a[i] 4 -7 2 -1 5
k=1 4 -3 2 1 6
k=2 2개집합 - - 4+2 0까지1개+sum[2-3]=4+1
0까지1개+sum[3-3]=4-1
1까지1개+sum[3-3]=-3+1
max = 5
0까지1개+sum[2-4]=4+6
0까지1개+sum[3-4]=4+4
0까지1개+sum[4-4]=4+5
1까지1개+sum[3-4]=-3+4
1까지1개+sum[4-4]=-3+5
2까지1개+sum[4-4]=2+5
max = 10

Pseudo code

dp[i][k] = 끝이 i이고 k개 집합을 쓴 최대 합
for(k = 집합 사용개수)
    for (j = i의 앞 = k-1개 사용한 부분)
        for (r = j의 뒤 i의 앞 = 부분합을 구할 부분)
            if (j까지 k-1개 사용 + sum[r-i]가 더 작으면)
                업데이트

Source code

#include <bits/stdc++.h>
using namespace std;

int main() {
    int n, m;
    cin >> n >> m;
    int a[n];
    for (int i=0; i<n; i++)
        cin >> a[i];

    int sum[n][n];
    for (int i=0; i<n; i++)
        sum[i][i] = a[i];
    for (int i=1; i<n; i++)
        for (int j=0; j<i; j++)
            sum[j][i] = sum[j][i-1]+a[i];

    int dp[n][m+1];
    for (int i=0; i<n; i++)
        for (int j=1; j<=m; j++)
            dp[i][j] = -2e9;
    for (int i=0; i<n; i++)
        for (int j=0; j<=i; j++)
            dp[i][1] = max(dp[i][1], sum[j][i]);
    for (int k=2; k<=m; k++)
        for (int i=0; i<n; i++)
            for (int j=0; j<i-1; j++)
                for (int r=j+2; r<=i; r++)
                    dp[i][k] = max(dp[i][k], dp[j][k-1]+sum[r][i]);

    int ans = -2e9;
    for (int i=0; i<n; i++)
        ans = max(ans, dp[i][m]);
    cout << ans;
}
728x90

댓글