Saturday, June 16, 2007

Final Indonesia National Contest 2007 :: Problem H - Tree Median

ACM/ICPC Indonesia National Contest 2007

Problem H

Tree Median

Time Limit: 3s


A tree is a connected graph containing no cycles. A vertex is called a median vertex if the total cost to reach all vertices is minimal. There could be more than one median vertex in a tree, that's why we define median as the set of all median vertices. To find median in a tree with small number of vertices is fairly easy task as you can solve this by a simple brute force program.

In the left figure, we can see a weighted tree with 5 vertices. The tree median is {B} because the total cost from vertex B to all other vertices is minimal.

B-A = 2    B-D = 7
B-C = 1    B-E = 7 + 5 = 12

TOTAL = 2 + 1 + 7 + 12 = 22

What if the number of vertices is quite large? This might be a problem since brute force program cost too much time. Given a weighted tree with N vertices, output the total cost to reach all vertices from its median.

Input

Input consists of several cases. Each case begins with an integer n (1<= n <= 10,000) denoting the number of vertices in a tree. Each vertex is numbered from 0...n-1. Each of the next n-1 lines contains three integers: a, b, and w (1<= w <= 100), which means a and b is connected by an edge with weight w.

Input is terminated when n is equal to 0. This input should not be processed.

Output

For each case, output a line containing the sum of cost of path to all other vertices from the tree median.

Sample Input

5
0 1 2
1 2 1
1 3 7
3 4 5
6
0 1 1
1 2 4
2 3 1
3 4 4
4 5 1
0

Sample Output

22
21


Problem Setter: Suhendry Effendy



Pembahasan Problem H - Tree Median

Diberikan suatu graph berbentuk tree. Graph ini memiliki N (1 <= N <= 10000) nodes dan N-1 edges yang mempunyai weight W (1 <= W <= 100). Carilah satu dari N node untuk dijadikan "root" sehingga total distance dari root ke semua nodes adalah minimum. Untuk contoh jelasnya (dipandu dengan gambar) silahkan lihat problem statement diatas.

Salah satu cara menyelesaikannya adalah dengan bruteforce, yaitu dengan mencoba satu-persatu node untuk dijadikan "root" dan menghitung total distance dari root tersebut ke semua nodes. Kendalanya adalah bagaimana cara menghitung total distance dari root ke semua node dalam O(1)? Berikut adalah cara yang saya ketahui, yaitu dengan menggeser root yang sekarang ke node tetangganya dan meng-update total distance ke semua nodes. Dibawah adalah langkah-langkah untuk melakukannya:

Pilih salah satu node untuk dijadikan "root", misalkan node 0. Lalu pre-calculate semua distance weight dan count nodes dari root ke semua nodes. Definisi distance weight adalah berapa total distance untuk node tersebut beserta sub-tree dari node tersebut. Definisi count nodes adalah berapa banyak node anak yang dimiliki node tersebut (termasuk node itu sendiri). Pre-calculate ini bisa dilakukan dengan DFS (rekursi). Note: untuk input yang besar, misalkan N = 10000, maka kedalaman rekursi dari DFS ini akan menyebabkan stack-overflow kecuali anda sudah mengubah setting compiler untuk stack size menjadi lebih besar (32 MB cukup).

Setelah pre-calculate distance weight dan count nodes dari node 0, otomatis anda bisa mengetahui total distance dari root ini ke semua nodes dalam O(edges dari node 0). Lalu anda bisa melakukan DFS (rekursi) ke dua yang bertugas untuk menggerakan "root" (yang sekarang di node 0) ke semua node lainnya sembari mengupdate tabel distance weight dan count nodes supaya anda bisa tetap menghitung total distance dari root yang baru ke semua nodes juga dalam O(edges dari root tersebut).

Jadi, total complexity dari algoritma ini adalah O(N) untuk pre-calculate distance dan count. O(N) untuk menggeser "root" ke N-1 nodes yang lain sambil mengcalculate distance nya dalam O(edges dari root node baru tersebut). Dengan amortized analysis, complexity dari calculate total distance untuk suatu root adalah O(1), karena ada N-1 edges total dan N nodes.
Code untuk problem ini saya pecah menjadi dua: tree-rec.cpp dan tree-stack.cpp. Keduanya mengimplement code diatas. Tetapi yang menggunakan rekursi (tree-rec.cpp) hanya bisa untuk N yang kecil (karena terbatas stack size di compiler, kecuali anda naikkan stack size compiler anda). Untuk code yang menggunakan stack buatan (tree-stack.cpp) bisa menghandle input lebih besar tanpa mengubah settingan compiler karena data dialokasikan di heap memory.

Kedua code menggunakan STL map, set, vector. Kalau tidak, maka codingnya akan jauh lebih panjanggggg..... Anda bisa menggunakan problem ini untuk belajar STL maupun rekursi dengan stack buatan.

Untuk Andrian Kurniady: sharing donk algo "DP problem I" nya gimana? pake rekursi topdown? apa bottom up?

#include <stdio.h>
#include <string.h>
#include <set>
#include <map>
#include <vector>
#include <algorithm>

using namespace std;

#define FOREACH(it,arr) for (__typeof((arr).begin()) it=(arr).begin(); it!=(arr).end(); it++)

vector<map<int,int> > con;
vector<long long> dist;
vector<int> cnt;
long long res;
int n;

void calcCnt(int i, int p){
    cnt[i]++;
    dist[i] = 0;
    FOREACH(it,con[i]){
        int j = it->first;
        int w = it->second;
        if (j!=p){
            calcCnt(j,i);
            cnt[i] += cnt[j];
            dist[i] += cnt[j] * w + dist[j];
        }
    }
}

void rec(int i, int p, int c, long long v){
    long long sum = (p==-1)? 0 : (v + c * con[p][i]);
    FOREACH(it,con[i]){
        int j = it->first;
        int w = it->second;
        if (j!=p) sum += cnt[j] * w + dist[j];
    }
    res = min(res, sum);

    FOREACH(it,con[i]){
        int j = it->first;
        int w = it->second;
        if (j!=p) rec(j, i, n-cnt[j], sum-(dist[j] + cnt[j]*w));
    }
}

int main(){
    while (scanf("%d",&n)!=EOF && n){
        con = vector<map<int,int> >(n);
        for (int i=1,a,b,w; i<n; i++){
            scanf("%d %d %d",&a,&b,&w);
            con[a][b] = w;
            con[b][a] = w;
        }

        cnt = vector<int>(n);
        dist = vector<long long>(n);
        calcCnt(0,-1);

        res = 1000000000000LL;
        rec(0,-1,0,0);
        printf("%lld\n",res);
    }
}
#include <stdio.h>
#include <string.h>
#include <set>
#include <map>
#include <vector>
#include <algorithm>

using namespace std;

#define FOREACH(it,arr) for (__typeof((arr).begin()) it=(arr).begin(); it!=(arr).end(); it++)

vector<map<int,int> > con;
vector<long long> dist;
vector<int> cnt;
long long res;
int n;

void calcCnt(int i, int p){
    vector<pair<int,int> > stk;
    vector<bool> vis;
    stk.push_back(make_pair(i,p));
    vis.push_back(false);
    while (stk.size()>0){
        i = stk.back().first;
        p = stk.back().second;

        if (!vis.back()){
            vis.back() = true;
            FOREACH(it,con[i]){
                int j = it->first;
                int w = it->second;
                if (j!=p){
                    stk.push_back(make_pair(j,i));
                    vis.push_back(false);
                }
            }
        } else {
            cnt[i] = 1;
            dist[i] = 0;
            FOREACH(it,con[i]){
                int j = it->first;
                int w = it->second;
                if (j!=p){
                    cnt[i] += cnt[j];
                    dist[i] += cnt[j] * w + dist[j];
                }
            }
            stk.pop_back();
            vis.pop_back();
        }
    }
}

struct ipcv {
    long long i,p,c,v;
};

void rec(int i, int p, int c, long long v){
    vector<ipcv> stk;
    vector<bool> vis;
    stk.push_back((ipcv){i,p,c,v});
    vis.push_back(false);
    while (stk.size()>0){
        ipcv t = stk.back();
        i = t.i;
        p = t.p;
        c = t.c;
        v = t.v;

        if (!vis.back()){
            vis.back() = true;

            long long sum = (p==-1)? 0 : (v + c * con[p][i]);
            FOREACH(it,con[i]){
                int j = it->first;
                int w = it->second;
                if (j!=p) sum += cnt[j] * w + dist[j];
            }
            res = min(res, sum);

            FOREACH(it,con[i]){
                int j = it->first;
                int w = it->second;
                if (j!=p){
                    stk.push_back((ipcv){j, i, n-cnt[j], sum-(dist[j] + cnt[j]*w)});
                    vis.push_back(false);
                }
            }
        } else {
            stk.pop_back();
            vis.pop_back();
        }
    }
}

int main(){
    while (scanf("%d",&n)!=EOF && n){
        con = vector<map<int,int> >(n);
        for (int i=1,a,b,w; i<n; i++){
            scanf("%d %d %d",&a,&b,&w);
            con[a][b] = w;
            con[b][a] = w;
        }

        cnt = vector<int>(n);
        dist = vector<long long>(n);
        calcCnt(0,-1);

        res = 1000000000000LL;
        rec(0,-1,0,0);
        printf("%lld\n",res);
    }
}

Kembali ke halaman utama

No comments:

Post a Comment