Development Project

[ Baekjoon - 10/05 ] - 1967번: 트리의 지름 본문

CodingTest/Baekjoon

[ Baekjoon - 10/05 ] - 1967번: 트리의 지름

나를 위한 시간 2022. 10. 5. 15:25
 

1967번: 트리의 지름

파일의 첫 번째 줄은 노드의 개수 n(1 ≤ n ≤ 10,000)이다. 둘째 줄부터 n-1개의 줄에 각 간선에 대한 정보가 들어온다. 간선에 대한 정보는 세 개의 정수로 이루어져 있다. 첫 번째 정수는 간선이 연

www.acmicpc.net

 

  • 소요 시간 : 4시간.....ㅎㅎㅎㅎㅎㅎ

 

 

  1. 문제를 읽고 이해하기
    • 제한
      • 시간 : 2초
      • 메모리 : 128MB
    • 문제
      • 노드의 개수 N개(1≤N≤10,000), 간선의 개수 N-1개 
      • 트리의 끝부분(리프노드)를 잡고 쫙 폈을때 가장 긴 길이를 구하는 문제
    • 이해
      • 문제 이해는 어렵지 않았는데, 문제에서 예시로 준 그림이 왜 45가 되는지 분석하지않고 문제에 바로 접근해서 많이 삽질을 했던것 같다.. 
  2. 문제를 익숙한 용어로 재정의와 추상화
    • 문제는 그냥 트리형태에서 긴 부분을 찾는것이라서 표현방식 정도만 생각해본것같다.
  3. 문제를 어떻게 해결할 것인가
    • 정말 많은 접근을 시도해본 문제인것같다..지금 돌이켜보면 방식만 많이 생각해보고 이를 해결하여 풀수있을지 깊게 고민하진 않았던것같다.. 사실 다른 접근법도 있긴하지만 너무 사소해서 크게 고민한것만 기록했다.
    • 1st 접근 - 인접리스트와 DFS를 이용해 루트에서 리프노드까지 비용 최댓값을 구하고, 가장 큰 두수를 합치기)
      • 위 접근은 문제를 제대로 읽지않아서 너무 단순하게 생각해 얻은 방법이다..
      • 문제에서 준 그림만 보아도 루트를 거치지 않고도 최대길이가 나올 수 있다는 예시인데 이를 간과해서 맞왜틀을 반복했다..
    • 2nd 접근 - 인접리스트와 DFS, LCA를 활용하기)
      • 다른 블로그나 코드를 보면 아무도 LCA를 언급하지않았는데, 나는 유독 위 문제를 볼때 LCA로 풀어야될것같다고 생각했다.
      • 1번접근이 틀린이유가 루트를 거치지 않을때 최대길이가 나올 수 있기 때문이었으니, LCA로 리프노드의 최소공통조상을 찾아서 그 중의 최댓값을 구하는 방식으로 결과를 얻을 수 있지 않을까 싶었다. 하지만 큰 틀의 방법은 알아도 dp를 어떻게 잡고 구현을 어떻게 할지 감이 잡히지 않아 고민만하다 끝난 접근이었다..
      • 그리고 아무리 생각해도 이 방법의 시간복잡도는 위 프로그램에 맞게 실행되진 않을거라 생각해서 다른분들의 코드가 궁금해 구글링을 해봤던것 같다.
    • 3rd 접근 - 트리의 지름 특징을 활용한 방법)
      • LCA로 구현을 해야할것같은데 도저히 구현이 되질않아 도움을 받기위해 질문하기, 블로그 등을 찾아보았다.
      • 그런데 100에 한 90정도는 트리의 지름의 특징에 대해 언급하며 이를 이용하여 푸는 문제라 나와있어서, 이 방법으로 접근해보려 했다.
        하지만, 다른분들이 나열해둔 특징들이 바로 와닿지가 않아서 이해를 위해 오랜시간 구글링에 시간을 쏟았지만 사실 아직도 이해가 완벽하게 되진 않았다..ㅋㅋ
      • 그중 그래도 어느정도 이해에 도움을 준 블로그가 있어 주소를 첨부한다.
        https://blog.myungwoo.kr/112
      • 이해는 제대로 못했지만 그래도 저 방법을 알고나니 쉽사리 통과가 되긴했다. 하지만, 실제 코딩시험에서는 이 특징을 몰라 문제를 날리고 싶진않아서 다른 방법을 좀더 생각해보기로 했다.
    • 4th 접근 - 리프노드부터 올라오면서 비용 최댓값을 보관하기 [dp] )
      • 사실 이 방법이 가장 이해가 잘 되었던것같다! 유튜브에도 검색해보았는데 이분이 거의 유일하게 트리의 지름에 대해 접근방법을 잘 설명해주셔서 도움이 많이 되었다! 마찬가지로 주소를 남겨놓겠다.
        https://www.youtube.com/watch?v=1VNWJTbE2pM 
      • 위 영상에서 설명해주신 트리는 좌측의 형태이고 나는 이를 활용해보려 오른쪽의 형태로 만들었다.
      • 왼쪽 트리의 연산을 설명해보자면 아래와 같다.
        1. 우선 리프노드인 8번에서 시작한다. 리프노드이므로 depth는 0, diameter도 0으로 설정한다.
        2. 8번의 부모인 6번으로 올라가보니 8번만을 자식으로 가지는 부모라서 depth 1증가, diameter도 1 증가한다.
        3. 4번도 6번만을 자식으로 가지므로 depth, diameter 각각 1씩 더한다.
        4. 2번은 4번만을 자식으로 가지지 않아서 다음 리프노드로 간다(9번)
        5. 9번도 리프노드니까 depth, diameter를 각각 0으로 세팅하고 여러 자식을 가지는 부모까지 올라간다.
        6. 9번이 쭉쭉 올라가면 5번까지 가는데 여기서 5번은 2로 각각 세팅되어있다. 
          2번이 4번과 5번의 공통부모이므로, 우선 2번에서 올라온 깊이에서 +1, 4번에서 올라온 깊이에서 +1을 한 값을 더해서 diameter로 하고, 2번을 자식으로 하는 1번에게 (2번에서 올라온 깊이에서 +1, 4번에서 올라온 깊이에서 +1)값 중 최댓값+1해서 위로 보낸다.
        7. 이와 같은 방식으로 루트에 도달할때까지 반복하면된다. 이해가 힘들다면 유튜브롤 보자
      • 나는 위 연산방식을 활용하여, 똑같이 depth와 diameter를 가지지만, 위 문제의 경우 간선에 가중치가 있으므로 +1이 아닌, +가중치 하는 방식으로 구하고자 했다. 
        1. 리프노드인 7번에서 시작한다. 리프노드이므로 depth는 0, diameter도 0으로 설정한다.
        2. 7번의 부모인 4번은 7번만 자식으로 가지지 않기때문에 8번으로 간다.
        3. 8번도 리프이므로 각각 0으로 세팅하고 올라가는데 4번의 자식노드들을 다 만났으므로 4번노드의 연산을 수행한다.
        4. 4번노드의 자식은 7에서 오는 경우와, 8에서 오는 경우 두가지 이고 각각 가중치가 1과 7이므로 diameter는 이 둘을 더한 8이되고, 위로갈때는 1과 7의 최댓값인 7을 보낸다.
        5. 이와 같은 방식으로 루트에 도달할때까지 반복한다. 그러면 남은 diameter(dp)중 가장 큰 값이 답이된다!
      • 위의 방법을 사용하려면 부모에서 자식에게 접근이 가능해야하고, 마찬가지로 자식에서 부모에게로도 접근이 가능해야해서 저장방법을 많이 고민했던것같다.
  4. 위 계획을 검증
    • 위 절차를 따라가보면 실제로 구해짐을 확인할 수 있다. 트리의 지름 증명은... 본인도 자세히 이해가 안가서 후에 이해되면 남기겠다.
  5. 계획 수행 (실제코드 작성) - 오늘은 Java와 Python의 실행 알고리즘이 다르다. 각각 3번과 4번 접근이다. 
    1. Java 11  
      • import java.io.*;
        import java.util.*;
        
        public class Main {
            private static int N,end,max;
            private static List<Node>[] input;
            private static boolean[] visit;
        
            public static void main(String[] args) throws IOException {
                //System.setIn(new FileInputStream(new File("src/input.txt")));
                BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
                StringTokenizer st = null;
        
                N = Integer.parseInt(br.readLine());
        
                input = new List[N+1];
                for(int i=1; i<=N; i++) {
                    input[i] = new ArrayList<Node>();
                }
        
                for(int i=1; i<N; i++) {
                    st = new StringTokenizer(br.readLine());
                    int a = Integer.parseInt(st.nextToken());
                    int b = Integer.parseInt(st.nextToken());
                    int c = Integer.parseInt(st.nextToken());
        
                    input[a].add(new Node(b, c));
                    input[b].add(new Node(a, c));
                }
        
                visit = new boolean[N+1];
                max = Integer.MIN_VALUE;
                dfs(1, 0);
        
                visit = new boolean[N+1];
                max = Integer.MIN_VALUE;
                dfs(end, 0);
        
                System.out.println(max);
            }
        
            private static void dfs(int cur, int dist) {
                if(max < dist) {
                    end = cur;
                    max = dist;
                }
        
                visit[cur] = true;
        
                for(Node next : input[cur]) {
                    if(visit[next.node]) {
                        continue;
                    }
                    dfs(next.node, dist + next.dist);
                }
            }
        
            public static class Node {
                public int node;
                public int dist;
        
                public Node(int node, int dist) {
                    this.node = node;
                    this.dist = dist;
                }
            }
        }
    2. Python 3
      • import sys
        from collections import deque
        
        n = int(sys.stdin.readline())
        if n == 1:
            print(0)
        else:
            data = [[] for _ in range(n+1)]
            for _ in range(n-1):
                parent,child,cost = map(int,input().split())
                data[parent].append([child,cost])
        
            level = []
            que = deque([[1,1]])
            level.append([1,1])
        
            # BFS
            while que:
                node,now_level =que.popleft()
                for child,_ in data[node]:
                    level.append([child,now_level+1])
                    que.append([child,now_level+1])
            level.sort(key=lambda x:-x[1])
        
        
            dp = [0]*(n+1)
            max_length = 0
            for node,node_level in level:
                temp = []
                if len(data[node])==0:
                    continue
                for child,child_length in data[node]:
                    temp.append(dp[child]+child_length)
                    if len(temp)>=2:
                        temp.sort(reverse=True)
                        parent,child=temp[0],temp[1]
                        max_length = max(max_length,parent+child)
                dp[node]=temp[0]
                if len(temp)==1:
                    max_length = max(max_length,dp[node])
            print(max_length)
  • 결과 - 2시간차이.. 후에는 좀더 빨리 풀수있기를 바란다..

 

Comments