Development Project

[ Baekjoon - 09/30 ] - 1208번: 부분수열의 합 2 본문

CodingTest/Baekjoon

[ Baekjoon - 09/30 ] - 1208번: 부분수열의 합 2

나를 위한 시간 2022. 9. 30. 19:23
 

1208번: 부분수열의 합 2

첫째 줄에 정수의 개수를 나타내는 N과 정수 S가 주어진다. (1 ≤ N ≤ 40, |S| ≤ 1,000,000) 둘째 줄에 N개의 정수가 빈 칸을 사이에 두고 주어진다. 주어지는 정수의 절댓값은 100,000을 넘지 않는다.

www.acmicpc.net

 

  • 소요 시간 : 1시간 40분

 

  1. 문제를 읽고 이해하기
    • 제한
      • 시간 : 1초
      • 메모리 : 256MB
    • 문제
      • N개의 정수(1≤N≤40), 더해서 얻고싶은 값 S(|S|≤1,000,000), N개의 크기를 가지는 수열
      • N개의 수열에서 얻을 수 있는 부분수열들 중 합이 S가 되는 경우의 수를 출력하는문제
    • 이해
      • 처음에 문제를 잘못 이해해서, 누적합+포인터로 구하면 되는 문제라 생각했는데 부분수열이기 때문에 누적합 배열 하나로는 구해지지 않는다는 사실을 늦게 깨달았다. 부분수열은 2^N개가 나오기 때문에 완탐으로는 불가능하고 다른 방법을 생각해야했다.
  2. 문제를 익숙한 용어로 재정의와 추상화
    • 문제이해가 어렵진 않아서 크게 도식화 하진 않았다. 
      대신, 누적합을 바로 적용하려다 간단히 구해지지 않는다는걸 깨달았다.
  3. 문제를 어떻게 해결할 것인가
    • 1st) 접근
      • 부분수열의 개수는 2^N개 이다.
      • 문제조건에서 N의 최댓값은 40이므로, 2^40번 연산을 하게 된다면 약 1조의 연산이기 때문에 1초안에 실행이 불가능하다.
      • 하지만 부분수열이기 때문에 2^N형태는 피할수가 없는데 어떻게 줄일수있을지 많은 시간을 투자하며 고민했지만, 바로 떠오르지 않아서 어떤 알고리즘 형태를 써야하는 문제인지 참고했다..
    • 2st 접근)
      • 이분탐색, 중간에서 만나기 문제여서 반으로 잘라서 생각해야할것같았다.
      • 그리고 반으로 나누면 2*20번 연산이므로 100만정도의 연산이라 충분히 1초내에 실행이 가능했다.
      • 입출력예제를 통해 그려보았고, 반으로 나누었을때 미리 합을 구해두면 투포인터로 접근이 가능할 것이라는생각을 하게되었다!
  4. 위 계획을 검증
    • 반으로 잘랐을때 처음부터 중간내용을 가진 배열을 v1이라 하고, 중간부터 끝은 v2라 하고, left라는 포인터는 0을, right라는 포인터는 가장 마지막(v2의 끝)을 가지도록 하자
    • 위에서 합으로 구해진 두 배열을 정렬하면 좌측으로 갈수록 작아지고 우측으로 갈수록 커지는 형태가 된다. 이를 포인터로 접근하며 S보다 큰값이면 right를 감소시키고 S보다 작으면 left를 증가, 같다면 합을 구하는 방식으로 구할 수 있다.
  5. 계획 수행 (실제코드 작성) - 언어는 다르지만, 흐름은 같다.
    1. Java 11
      • import java.util.*;
        import java.io.*;
        
        public class Main{
            static int N,S;
            static long ans=0;
            static int[] v;
            static List<Integer> v1,v2;
        
            public static void main(String[] args) throws IOException {
                //System.setIn(new FileInputStream("src/input.txt"));
        
                BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
                StringTokenizer st = new StringTokenizer(br.readLine());
                N = Integer.parseInt(st.nextToken());
                S = Integer.parseInt(st.nextToken());
        
                v = new int[N];
                st = new StringTokenizer(br.readLine());
                for (int i = 0; i < N; i++) {
                    v[i] = Integer.parseInt(st.nextToken());
                }
        
                v1 = new ArrayList<>();
                v2 = new ArrayList<>();
        
                f(v, 0, N/2, 0, v1);
                f(v, N/2, N, 0, v2);
        
                Collections.sort(v1);
                Collections.sort(v2);
        
                int left = 0;
                int right = v2.size()-1;
        
                while(left<v1.size() && right>=0){
                    int lv = v1.get(left);
                    int rv = v2.get(right);
                    if(lv+rv==S){
        
                        long lc = 0;
                        long rc = 0;
        
                        while(left<v1.size() && v1.get(left)==lv){
                            lc++;
                            left++;
                        }
        
                        while(right>=0 && v2.get(right)==rv){
                            rc++;
                            right--;
                        }
        
                        ans += lc*rc;
                    }
        
                    if(lv+rv>S){
                        right--;
                    }
        
                    if(lv+rv <S) {
                        left++;
                    }
                }
                if(S==0){
                    System.out.println(ans-1);
                }else {
                    System.out.println(ans);
                }
            }
        
            private static void f(int[] v, int i, int n, int sum, List<Integer> v2) {
                if(i==n){
                    v2.add(sum);
                    return;
                }
        
                f(v, i+1, n, sum + v[i], v2);
                f(v, i+1, n, sum, v2);
            }
        }
    2. Python 3
      • import sys
        
        N, S = map(int, sys.stdin.readline().split())
        v = list(map(int, sys.stdin.readline().split()))
        
        v1 = []
        v2 = []
        
        def f(v, i, n, sum, v1):
            if i == n:
                v1.append(sum)
                return
        
            f(v, i + 1, n, sum + v[i], v1)
            f(v, i + 1, n, sum, v1)
        
        
        f(v, 0, N // 2, 0, v1)
        f(v, N // 2, N, 0, v2)
        
        v1.sort()
        v2.sort()
        
        
        left = 0
        right = len(v2) - 1
        ans = 0
        
        while left < len(v1) and right >= 0:
            lv = v1[left]
            rv = v2[right]
            if lv + rv == S:
                lc = 0
                rc = 0
                while left < len(v1) and v1[left] == lv:
                    lc += 1
                    left += 1
        
                while right >= 0 and v2[right] == rv:
                    rc += 1
                    right -= 1
        
                ans += lc * rc
        
            if lv + rv > S:
                right -= 1
            if lv + rv < S:
                left += 1
        print(ans-1 if S == 0 else ans)

 

  • 결과 - 단순 누적합이라 생각해 한번틀렸다.

Comments