{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.model_selection import GridSearchCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the data\n",
    "# 1. Get the data with read_csv (path is relative to the notebook's working directory)\n",
    "data = pd.read_csv('./data/FBlocation/train.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>row_id</th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "      <th>place_id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.7941</td>\n",
       "      <td>9.0809</td>\n",
       "      <td>54</td>\n",
       "      <td>470702</td>\n",
       "      <td>8523065625</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>5.9567</td>\n",
       "      <td>4.7968</td>\n",
       "      <td>13</td>\n",
       "      <td>186555</td>\n",
       "      <td>1757726713</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>8.3078</td>\n",
       "      <td>7.0407</td>\n",
       "      <td>74</td>\n",
       "      <td>322648</td>\n",
       "      <td>1137537235</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>7.3665</td>\n",
       "      <td>2.5165</td>\n",
       "      <td>65</td>\n",
       "      <td>704587</td>\n",
       "      <td>6567393236</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>4.0961</td>\n",
       "      <td>1.1307</td>\n",
       "      <td>31</td>\n",
       "      <td>472130</td>\n",
       "      <td>7440663949</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   row_id       x       y  accuracy    time    place_id\n",
       "0       0  0.7941  9.0809        54  470702  8523065625\n",
       "1       1  5.9567  4.7968        13  186555  1757726713\n",
       "2       2  8.3078  7.0407        74  322648  1137537235\n",
       "3       3  7.3665  2.5165        65  704587  6567393236\n",
       "4       4  4.0961  1.1307        31  472130  7440663949"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
    ],
   "source": [
    "# Preview the first rows to check which columns were loaded\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. Get the data: read_csv\n",
    "\n",
    "# 2. Basic data processing: split the dataset, turn the timestamp into explicit time features (weekday, day of month, hour), drop place_ids with few check-ins\n",
    "\n",
    "# 3. Feature engineering: standardization\n",
    "\n",
    "# 4. Machine learning\n",
    "# 4.1 Create the model\n",
    "# 4.2 Train the model\n",
    "\n",
    "# 5. Model evaluation: classification accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2.数据基本处理  \n",
    "# 取一小部分数据做演示\n",
    "data = data.query('x>2 & x<2.5 & y>2 & y<2.5')\n",
    "\n",
    "# 把时间戳提出成明显的时间特征（周几，几号，几点） 特征提取    用户地址 -> 经纬度\n",
    "time = pd.to_datetime(data['time'], unit='s')\n",
    "time = pd.DatetimeIndex(time)\n",
    "data['weekday'] = time.weekday\n",
    "data['day'] = time.day\n",
    "data['hour'] = time.hour\n",
    "\n",
    "\n",
    "# 剔除签到人数少的place_id\n",
    "# 统计place_id出现的次数\n",
    "temp = data.groupby('place_id')['row_id'].count()\n",
    "res = temp[temp > 3].index\n",
    "data = data[data['place_id'].isin(res)]\n",
    "\n",
    "\n",
    "# 数据集划分\n",
    "x = data[['x', 'y', 'accuracy', 'weekday', 'day', 'hour']]\n",
    "y = data['place_id']\n",
    "x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/lib/python3.7/site-packages/sklearn/preprocessing/data.py:645: DataConversionWarning: Data with input dtype int64, float64 were all converted to float64 by StandardScaler.\n",
      "  return self.partial_fit(X, y)\n",
      "/anaconda3/lib/python3.7/site-packages/sklearn/base.py:464: DataConversionWarning: Data with input dtype int64, float64 were all converted to float64 by StandardScaler.\n",
      "  return self.fit(X, **fit_params).transform(X)\n",
      "/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:4: DataConversionWarning: Data with input dtype int64, float64 were all converted to float64 by StandardScaler.\n",
      "  after removing the cwd from sys.path.\n"
     ]
    }
   ],
   "source": [
    "# 3.特征工程 标准化\n",
    "transfer = StandardScaler()\n",
    "x_train = transfer.fit_transform(x_train)\n",
    "x_test = transfer.transform(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_split.py:652: Warning: The least populated class in y has only 1 members, which is too few. The minimum number of members in any class cannot be less than n_splits=3.\n",
      "  % (min_groups, self.n_splits)), Warning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "GridSearchCV(cv=3, error_score='raise-deprecating',\n",
       "       estimator=KNeighborsClassifier(algorithm='kd_tree', leaf_size=30, metric='minkowski',\n",
       "           metric_params=None, n_jobs=None, n_neighbors=5, p=2,\n",
       "           weights='uniform'),\n",
       "       fit_params=None, iid='warn', n_jobs=None,\n",
       "       param_grid={'n_neighbors': [1, 3, 5]}, pre_dispatch='2*n_jobs',\n",
       "       refit=True, return_train_score='warn', scoring=None, verbose=0)"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 4.机器学习\n",
    "# 4.1 创建模型\n",
    "estimator = KNeighborsClassifier(algorithm='kd_tree')\n",
    "  # 构建参数字典\n",
    "param_dict = {'n_neighbors':[1,3,5]}\n",
    "  # 初始化GSCV估计器\n",
    "estimator_gscv = GridSearchCV(estimator, param_grid=param_dict, cv=3)\n",
    "\n",
    "# 4.2 训练模型\n",
    "estimator_gscv.fit(x_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3602829711975745"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
    ],
   "source": [
    "# 5. Model evaluation (classification accuracy): score the fitted grid search on the held-out test split\n",
    "estimator_gscv.score(x_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='kd_tree', leaf_size=30, metric='minkowski',\n",
       "           metric_params=None, n_jobs=None, n_neighbors=1, p=2,\n",
       "           weights='uniform')"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
    ],
   "source": [
    "# Retrieve the best model found by the grid search\n",
    "estimator_gscv.best_estimator_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'n_neighbors': 1}"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
    ],
   "source": [
    "# Retrieve the best hyperparameters found by the grid search\n",
    "estimator_gscv.best_params_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
