from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.covariance import ledoit_wolf
import numpy as np


class ShrinkageLinearRegression(BaseEstimator, RegressorMixin):
    """Least-squares linear regression with optional Ledoit-Wolf shrinkage.

    ``fit`` solves the normal equations ``G B = X1^T Y``, where ``X1`` is ``X``
    with a leading column of ones (for the intercept). ``G`` is either the
    plain Gram matrix ``X1^T X1`` or, when ``shrinkage`` is true, the
    Ledoit-Wolf shrunk version of it (computed with ``assume_centered=True``
    and rescaled to the same scale as ``X1^T X1``).

    NOTE: the ones column is passed to ``ledoit_wolf`` together with the
    features, so the intercept row/column of the Gram matrix is shrunk too.

    Parameters
    ----------
    shrinkage : bool, default=True
        Whether to regularize the Gram matrix with Ledoit-Wolf shrinkage.

    Attributes
    ----------
    B : ndarray of shape (n_features + 1, n_targets)
        Full coefficient matrix; row 0 holds the intercept(s).
    coef_ : ndarray of shape (n_targets, n_features)
        Feature coefficients (scikit-learn layout).
    intercept_ : ndarray of shape (n_targets,)
        Intercept per target.
    shrinkage_ : float
        Ledoit-Wolf shrinkage intensity used by the last ``fit``
        (0.0 when ``shrinkage`` is false).
    """

    def __init__(self, shrinkage=True):
        self.shrinkage = shrinkage

    @staticmethod
    def _add_intercept_column(X):
        """Return ``X`` as an array with a column of ones prepended."""
        X = np.asarray(X)
        ones = np.ones((X.shape[0], 1))
        return np.concatenate((ones, X), axis=1)

    def fit(self, X, y=None):
        """Fit the regression coefficients.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training features.
        y : array-like of shape (n_samples,) or (n_samples, n_targets)
            Training targets. The ``None`` default exists only for
            scikit-learn signature compatibility; passing ``None`` raises.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``y`` is None.
        numpy.linalg.LinAlgError
            If the (possibly shrunk) Gram matrix is singular.
        """
        if y is None:
            raise ValueError("ShrinkageLinearRegression.fit requires y.")

        X1 = self._add_intercept_column(X)
        Y = np.asarray(y)
        # Treat 1-D y as a single target column so B is always 2-D.
        Y = Y.reshape(Y.shape[0], -1)
        n_samples = X1.shape[0]

        if self.shrinkage:
            # With assume_centered=True, ledoit_wolf shrinks X1^T X1 / n_samples.
            # Rescale by n_samples (not n_samples - 1) so the result is on the
            # same scale as X1^T X1 and reduces exactly to OLS at zero shrinkage.
            gram, self.shrinkage_ = ledoit_wolf(X1, assume_centered=True)
            gram = gram * n_samples
        else:
            gram = X1.T @ X1
            self.shrinkage_ = 0.0

        # B: [X-vars x Y-vars]. Solve instead of forming an explicit inverse:
        # same solution, better numerical stability.
        self.B = np.linalg.solve(gram, X1.T @ Y)

        # for compatibility with scikit-learn
        self.coef_ = self.B[1:, :].T
        self.intercept_ = self.B[0, :]

        return self

    def predict(self, X):
        """Predict targets for ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        ndarray of shape (n_samples, n_targets)
            Always 2-D, even when ``fit`` was called with a 1-D ``y``.
        """
        return self._add_intercept_column(X) @ self.B