mondegreen

[240310] 알고리즘 리부트 28일차 - 백준 2143 자바 본문

알고리즘 풀이 및 리뷰/백준

[240310] 알고리즘 리부트 28일차 - 백준 2143 자바

앙갱 2024. 3. 11. 10:57
반응형

- 백준 2143 두 배열의 합

개인적으로 그간 이분탐색을 작성하며 느꼈던 자괴감을 해소할 수 있었던 문제였다. 처음부터 로직을 먼저 의사코드로 설계한 후에 코드로 작성했다. 부 배열을 만들어내는 과정에서 자꾸만 3중 반복문을 순회하게 되어 분명 시간 초과가 날 것이라고 생각했다. 또한 부배열을 구성하는 원소의 개수를 구분해서 그룹으로 비교하려다 보니 ArrayList<Integer>[] 라는 자료구조에 담게 되었고 이를 다시 한 배열마다 정렬을 진행했다. 당연히. 시간초과가 날 수밖에 없었다. 한참 돌다가 다른 사람들의 풀이를 보니 해시맵을 이용해 풀던데 도저히 내 코드를 버릴 수가 없었고 코드 최적화를 통해 풀어낼 수 있었다.

수정한 사항은 위에서 말한 두 가지이다. 첫째는 부배열을 추출하는 방식이다. 원소의 개수별로 즉 연속되는 원소들로 부배열을 추출해야 했기 때문에 반복문을 사용하는 것은 불가피한 방식이었다. 기존의 코드는 1) 원소의 갯수를 지정 2) 시작 원소 인덱스 지정 3) 끝원소까지의 인덱스 이렇게 3중 반복문을 작성하였다. 하지만 개선한 코드는 1차적으로 데이터를 입력받을 때 누적합 배열을 함께 선언해서 입력하는 것이다. 이렇게 작성하면 1번 인덱스에는 1번 원소만, 2번 인덱스에는 첫번째 원소와 두번째 원소의 합이 담기게 된다. 그리고 2번과 3번 인덱스의 합을 구하고자 한다면 누적합 3번 원소에서 1번 원소를 빼주면 되는 것이다.

이 구간합은 다음과 같이 활용한다. 첫번째 반복문은 부 배열의 원소 수를 결정하는 반복문을 1부터 n까지 순회하고 두번째 반복문은 해당하는 원소의 갯수만큼 구간을 찾아 리스트에 담아준다. 예를 들어 원소가 1개인 부배열을 만들 경우 i = 1이고 j는 i부터 시작해 n까지 반복문을 돈다. 이 때 j는 부배열을 구성하는 첫번째 원소의 인덱스이다. 구간합은 acc[j]로 j까지의 누적합에서 acc[j-i]을 이용해 i 개수 만큼만 합하고 i-1이하는 빼주는 것이다. 예시를 이어서 들어보면 i가 1이고 j가 3이라면 acc[3]에서 acc[2]를 빼줌으로써 3번째 원소만 가진 부배열을 만들 수 있다. 또 i가 3이고 j가 n이라면 acc[n], 즉 원소 1,2,3...n번째 값을 더한 값에서 acc[n-3]을 빼줌으로써 해당 원소를 포함한 3개 원소의 합만 가질 수 있는 것이다. 

둘째로 불필요하게 부배열 원소 개수별로 구분하여 복잡하게 자료구조에 담은 것을 보완했다. 이를 단순히 어레이리스트로 변경하여 각 배열을 한번에 담았다. a 리스트 원소를 기준으로 b 리스트에서 타겟을 찾으면 되는 것이기 때문에 굳이 원소의 수로 구분하여 값을 가지고 있을 필요가 없고 오히려 반복문을 이중으로 수행하는 비효율이 발생한다. 또한 배열의 원소인 어레이리스트를 매번 정렬해야 하는 정말정말 불필요한 작업을 수행하게 된다. 

위와 같이 개선하고 찾고자 하는 수 이상의 값이 최초로 등장하는 인덱스를 추출하는 lowerBound와 찾고자 하는 수를 초과하는 값이 최초로 등장하는 인덱스를 추출하는 upperBound를 직접 구현하여 문제를 풀었다. 아래는 정답 코드이고 두번째 코드는 시간 초과를 받았던 최초 코드이다.

import java.io.*;
import java.util.*;

public class Main {
    public static int t, n, m;
    public static int[] arrA, arrB, accA, accB;
    public static ArrayList<Integer> combA;
    public static ArrayList<Integer> combB;

    public static void main(String[] args) throws IOException {

        input();

        combA = new ArrayList<>();
        combB = new ArrayList<>();

        getCombination();

        // 이분 탐색을 위해 정렬

        Collections.sort(combA);
        Collections.sort(combB);

        long ans = 0L;


        for (int j = 0; j < combA.size(); j++) {

            int partA = combA.get(j);
            int target = t - partA;

            // target의 lowerBound와 upperBound를 찾아서 ans에 더해주면 된다.


            int s = findLowerBound(target);
            int e = findUpperBound(target);
            ans += e - s;
        }

        System.out.println(ans);
    }

    private static int findLowerBound(int target) {
        // 특정 수 이상의 값이  최초로 등장하는 인덱스 구하기
        int sIdx = combB.size();

        int l = 0;
        int r = combB.size() - 1;

        while (l <= r) {

            int mid = (l + r) / 2;

            if (combB.get(mid) < target) l = mid + 1;
            else {
                sIdx = mid;
                r = mid - 1;
            }
        }
        return sIdx;
    }

    private static int findUpperBound(int target) {
        // 특정 수를 초과하는 값이 최초로 등장하는 인덱스 구하기
        int eIdx = combB.size();

        int l = 0;
        int r = combB.size() - 1;

        while (l <= r) {

            int mid = (l + r) / 2;

            if (combB.get(mid) <= target) l = mid + 1;
            else {
                eIdx = mid;
                r = mid - 1;
            }
        }
        return eIdx;
    }

    private static void getCombination() {

        for (int i = 1; i <= n; i++) {
            for (int j = i; j <= n; j++) {
                int tmpSum = accA[j] - accA[i - 1];
                combA.add(tmpSum);
            }
        }

        for (int i = 1; i <= m; i++) {
            for (int j = i; j <= m; j++) {
                int tmpSum = accB[j] - accB[i - 1];
                combB.add(tmpSum);
            }
        }
    }

    public static void input() throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st;

        t = Integer.parseInt(br.readLine());

        n = Integer.parseInt(br.readLine());
        arrA = new int[n + 1];
        accA = new int[n + 1];

        st = new StringTokenizer(br.readLine());
        for (int i = 1; i <= n; i++) {
            arrA[i] = Integer.parseInt(st.nextToken());
            accA[i] = accA[i - 1] + arrA[i];
        }

        m = Integer.parseInt(br.readLine());
        arrB = new int[m + 1];
        accB = new int[m + 1];

        st = new StringTokenizer(br.readLine());
        for (int i = 1; i <= m; i++) {
            arrB[i] = Integer.parseInt(st.nextToken());
            accB[i] = accB[i - 1] + arrB[i];
        }
    }
}

 

- 시간 초과 코드

import java.io.*;
import java.util.*;

public class Main {
    public static int t, n, m;
    public static int[] arrA, arrB;
    public static ArrayList<Integer>[] combA;
    public static ArrayList<Integer>[] combB;

    public static void main(String[] args) throws IOException {

        input();

        combA = new ArrayList[n + 1];
        combB = new ArrayList[m + 1];

        for (int i = 0; i <= n; i++) {
            combA[i] = new ArrayList<>();
        }

        for (int i = 0; i <= m; i++) {
            combB[i] = new ArrayList<>();
        }

        getCombination();

        // 이분 탐색을 위해 정렬
        for (int i = 0; i <= n; i++) {
            Collections.sort(combA[i]);
        }

        for (int i = 0; i <= m; i++) {
            Collections.sort(combB[i]);
        }

        long ans = 0L;

        for (int i = 1; i <= n; i++) {
            if (combA[i].size() == 0) continue;
            for (int j = 0; j < combA[i].size(); j++) {

                int partA = combA[i].get(j);
                int target = t - partA;

                // target의 lowerBound와 upperBound를 찾아서 ans에 더해주면 된다.

                for (int k = 1; k <= m; k++) {

                    int s = findLowerBound(target, k);
                    int e = findUpperBound(target, k);
                    ans += e - s;
                }
            }
        }

        System.out.println(ans);
    }

    private static int findLowerBound(int target, int k) {
        // 특정 수 이상의 값이  최초로 등장하는 인덱스 구하기
        int sIdx = combB[k].size();

        int l = 0;
        int r = combB[k].size() - 1;

        while (l <= r) {

            int mid = (l + r) / 2;

            if (combB[k].get(mid) < target) l = mid + 1;
            else {
                sIdx = mid;
                r = mid - 1;
            }
        }
        return sIdx;
    }

    private static int findUpperBound(int target, int k) {
        // 특정 수를 초과하는 값이 최초로 등장하는 인덱스 구하기
        int eIdx = combB[k].size();

        int l = 0;
        int r = combB[k].size() - 1;

        while (l <= r) {

            int mid = (l + r) / 2;

            if (combB[k].get(mid) <= target) l = mid + 1;
            else {
                eIdx = mid;
                r = mid - 1;
            }
        }
        return eIdx;
    }

    private static void getCombination() {

        for (int i = 1; i <= n; i++) { // i: 구성 원소의 개수
            int sum = 0;
            for (int start = 1; start <= n - i + 1; start++) {
                sum += arrA[start + i - 1]; // 부분합을 이전 부분합에서 더해가며 계산
                if (sum <= t) {
                    combA[i].add(sum);
                } else {
                    break; // 이미 t를 초과했으므로 더 이상 계산할 필요 없음
                }
            }
        }

        // 배열 B의 부분합 구하기
        for (int i = 1; i <= m; i++) { // i: 구성 원소의 개수
            int sum = 0;
            for (int start = 1; start <= m - i + 1; start++) {
                sum += arrB[start + i - 1]; // 부분합을 이전 부분합에서 더해가며 계산
                if (sum <= t) {
                    combB[i].add(sum);
                } else {
                    break; // 이미 t를 초과했으므로 더 이상 계산할 필요 없음
                }
            }
        }
    }

    public static void input() throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st;

        t = Integer.parseInt(br.readLine());

        n = Integer.parseInt(br.readLine());
        arrA = new int[n + 1];

        st = new StringTokenizer(br.readLine());
        for (int i = 1; i <= n; i++) {
            arrA[i] = Integer.parseInt(st.nextToken());
        }

        m = Integer.parseInt(br.readLine());
        arrB = new int[m + 1];

        st = new StringTokenizer(br.readLine());
        for (int i = 1; i <= m; i++) {
            arrB[i] = Integer.parseInt(st.nextToken());
        }
    }
}
반응형