ベスパリブ

プログラミングを主とした日記・備忘録です。ベスパ持ってないです。

Educational DP Contest I - Coins 解いた

解説

// dp[i][j] := i枚目を投げたとき表がj枚の確率
// p[i] := i枚目が表が出る確率

ということっぽいのはわかるのだが、ここからどうすればいいのかわからなかったので、一つずつdpの状態を書いていった。

// 1枚投げて表が0枚の確率
dp[1][0] = 1-p[1]; 
// 1枚投げて表が1枚の確率
dp[1][1] = p[1];    
// 2枚投げて表が0枚の確率 = (1枚投げて表が0枚の状態で、2枚目が裏)
dp[2][0] = dp[1][0]*(1-p[2]);  
// 2枚投げて表が1枚の確率 = (1枚投げて表が0枚の状態で、2枚目が表) + (1枚投げて表が1枚の状態で、2枚目が裏)
dp[2][1] = dp[1][0]*p[2] + dp[1][1]*(1-p[2]);  
// 2枚投げて表が2枚の確率 = (1枚投げて表が1枚の状態で、2枚目が表)
dp[2][2] = dp[1][1]*p[2];  
// 3枚投げて表が0枚の確率
dp[3][0] = dp[2][0]*(1-p[3]);  
// 3枚投げて表が1枚の確率
dp[3][1] = dp[2][0]*p[3] + dp[2][1]*(1-p[3]);  
// 3枚投げて表が2枚の確率
dp[3][2] = dp[2][1]*p[3] + dp[2][2]*(1-p[3]);  
// 3枚投げて表が3枚の確率
dp[3][3] = dp[2][2]*p[3];  

なんだが法則性が見えてきた。

求める答えは「表の個数が裏の個数を上回る確率」なので、dp[N][(N/2)+1]~dp[N][N]の総和を求めれば良い。

というわけで、最終的なコードは以下のようになった。

#include <bits/stdc++.h>
#define _USE_MATH_DEFINES  // M_PI等のフラグ
#define MOD 1000000007
#define COUNTOF(array) (sizeof(array)/sizeof(array[0]))
#define rep(i,n) for (int i = 0; i < (n); ++i)
using namespace std;
using ll = long long;
using pii = pair<int,int>;
const int INF = 1001001001;
void chmax(int& x, int y) { x = max(x,y); }
void chmin(int& x, int y) { x = min(x,y); }


void solve(){
    int N; cin >> N;
    vector<double> p(N+1, 0.0);
    for (int i=1; i<N+1; i++) {
        cin >> p[i];
    }

    // dp[i][j] := i枚目まで投げたときの表がj枚になる確率
    vector<vector<double>> dp(N+1, vector<double>(N+1, 1.0));
    dp[1][0] = 1.0-p[1];  // 1枚投げて表が0枚の確率
    dp[1][1] = p[1];    // 1枚投げて表が1枚の確率

    // dpを埋める
    for (int i=2; i<N+1; i++) {
        for (int j=0; j<i+1; j++) {
            if (j==0) {
                dp[i][j] = dp[i-1][0] * (1.0-p[i]);
            }
            else if (j==i) {
                dp[i][j] = dp[i-1][j-1]*p[i];
            }
            else {
                dp[i][j] = dp[i-1][j-1]*p[i] + dp[i-1][j]*(1.0-p[i]);
            }
        }
    }

    // 求める答えはdp[N][(N/2)+1]~dp[N][N]の総和
    double ans = 0.0;
    for (int j=(N/2)+1; j<N+1; j++) {
        ans += dp[N][j];
    }

    printf("%.10f", ans);
}


int main(int argc, char const *argv[]){
    solve();
    return 0;
}

以下のけんちょんさんの解説を見たらもっとすっきり書けるようです。

Educational DP Contest の F ~ J 問題の解説と類題集 - Qiita