-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
The basic architecture of the application is built.
- Loading branch information
1 parent
e7122ab
commit e4d3e2a
Showing
11 changed files
with
504 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
WEIGHTS_USE=True | ||
|
||
KNN_WEIGHT=0.8 | ||
LOGISTICS_REGRESSION_WEIGHT=0.7 | ||
XGBOOST_WEIGHT=1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import os | ||
import inquirer | ||
import numpy as np | ||
from dotenv import load_dotenv | ||
from typing import Dict, List | ||
|
||
import models | ||
from models.base import BaseModel | ||
from utils.other import print_pure_banner | ||
|
||
|
||
class App: | ||
allowed_models: Dict[str, BaseModel] | ||
|
||
def __init__(self, weights: Dict[str, float], weights_use: bool = True) -> None: | ||
self.allowed_models = {} | ||
|
||
for model in BaseModel.__subclasses__(): | ||
weight = weights.get(model.name.upper(), 1) | ||
|
||
try: | ||
weight = float(weight) | ||
except ValueError: | ||
raise ValueError(f"Вес для модели {model.name} должен быть числом") | ||
|
||
if weight < 0 or weight > 1: | ||
raise ValueError(f"Вес для модели {model.name} должен быть от 0 до 1") | ||
|
||
self.allowed_models[model.name] = model(weight=weight) | ||
|
||
def run(self) -> None: | ||
print_pure_banner() | ||
self.ask_questions() | ||
|
||
def ask_questions(self) -> None: | ||
main_menu_answer = inquirer.prompt([ | ||
inquirer.List('answer', | ||
message="Что будем делать?", | ||
choices=['Дообучение модели', 'Получить предсказание', 'Выход'], | ||
) | ||
])['answer'] | ||
|
||
print_pure_banner() | ||
|
||
match main_menu_answer: | ||
case 'Дообучение модели': | ||
model = inquirer.prompt([ | ||
inquirer.List('answer', | ||
message="Выберите модель для дообучения", | ||
choices=self.allowed_models, | ||
) | ||
])['answer'] | ||
self.allowed_models[model].train() | ||
|
||
case 'Получить предсказание': | ||
models = inquirer.prompt([ | ||
inquirer.Checkbox('answer', | ||
message="Выберите модели для предсказания.", | ||
choices=self.allowed_models, | ||
) | ||
])['answer'] | ||
print(models) | ||
self.predict(list_models=models) | ||
|
||
case 'Выход': | ||
exit() | ||
|
||
def predict(self, list_models: List[BaseModel], use_weights: bool = True) -> float: | ||
results = {i.weight: i.predict() for i in list_models} | ||
|
||
if use_weights: | ||
return np.average(results.values(), axis=0, weights=results.keys()) | ||
else: | ||
return np.mean(results.values(), axis=0) | ||
|
||
|
||
if __name__ == "__main__": | ||
load_dotenv() | ||
weights_use = os.environ.get('WEIGHTS_USE', 'true').lower() == 'true' | ||
|
||
weights = { | ||
key.replace("_WEIGHT", ""): value | ||
for key, value in os.environ.items() if key.endswith('_WEIGHT') | ||
} | ||
|
||
app = App(weights=weights, weights_use=weights_use) | ||
app.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from .base import BaseModel | ||
|
||
|
||
class KNN(BaseModel): | ||
name = 'KNN' | ||
|
||
def train(self) -> None: | ||
pass | ||
|
||
def predict(self) -> float: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from .base import BaseModel | ||
|
||
|
||
class XGBoost(BaseModel): | ||
name = 'XGBoost' | ||
|
||
def train(self) -> None: | ||
pass | ||
|
||
def predict(self) -> float: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .base import BaseModel | ||
from .KNN import KNN | ||
from .logistics_regression import LogisticsRegression | ||
from .XGBoost import XGBoost | ||
|
||
|
||
__all__ = ['BaseModel', 'KNN', 'LogisticsRegression', 'XGBoost'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class BaseModel(ABC): | ||
name: str | ||
weight: float | ||
|
||
def __init__(self, weight: float = 1) -> None: | ||
self.weight = weight | ||
|
||
@abstractmethod | ||
def train(self) -> None: ... | ||
|
||
@abstractmethod | ||
def predict(self) -> float: ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from .base import BaseModel | ||
|
||
|
||
class LogisticsRegression(BaseModel): | ||
name = 'Logistics_Regression' | ||
|
||
def train(self) -> None: | ||
pass | ||
|
||
def predict(self) -> float: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
banner = """ | ||
╭━━━┳━━┳━━━┳╮╭━╮╭━━━┳━━━┳━━━┳━━━━┳━━┳━╮╱╭┳╮╱╱╭╮ | ||
╰╮╭╮┣┫┣┫╭━╮┃┃┃╭╯╰╮╭╮┃╭━━┫╭━╮┃╭╮╭╮┣┫┣┫┃╰╮┃┃╰╮╭╯┃ | ||
╱┃┃┃┃┃┃┃╰━━┫╰╯╯╱╱┃┃┃┃╰━━┫╰━━╋╯┃┃╰╯┃┃┃╭╮╰╯┣╮╰╯╭╯ | ||
╱┃┃┃┃┃┃╰━━╮┃╭╮┃╱╱┃┃┃┃╭━━┻━━╮┃╱┃┃╱╱┃┃┃┃╰╮┃┃╰╮╭╯ | ||
╭╯╰╯┣┫┣┫╰━╯┃┃┃╰╮╭╯╰╯┃╰━━┫╰━╯┃╱┃┃╱╭┫┣┫┃╱┃┃┃╱┃┃ | ||
╰━━━┻━━┻━━━┻╯╰━╯╰━━━┻━━━┻━━━╯╱╰╯╱╰━━┻╯╱╰━╯╱╰╯ | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import inspect | ||
import subprocess | ||
from .banner import banner | ||
|
||
|
||
def clear_screen() -> None: | ||
subprocess.run('clear') | ||
|
||
def print_banner() -> None: | ||
print(banner) | ||
|
||
def print_pure_banner() -> None: | ||
clear_screen() | ||
print_banner() | ||
|
Oops, something went wrong.