Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hf token #427

Merged
merged 8 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
43 changes: 43 additions & 0 deletions refact_utils/huggingface/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import json

from enum import Enum
from typing import Optional

from huggingface_hub import repo_info
from huggingface_hub.utils import GatedRepoError
from huggingface_hub.utils import RepositoryNotFoundError
from refact_utils.scripts import env


def huggingface_hub_token() -> Optional[str]:
    """Read the Hugging Face API token from the integrations config file.

    Returns:
        The value of the "huggingface_api_key" entry, or None when the
        config file is missing/unreadable, is not valid JSON, or lacks the
        key. Callers treat None as "anonymous Hub access".
    """
    try:
        with open(env.CONFIG_INTEGRATIONS, "r") as f:
            return json.load(f)["huggingface_api_key"]
    # Narrow catch instead of bare `except:` (which also swallowed
    # SystemExit/KeyboardInterrupt): OSError covers a missing/unreadable
    # file, ValueError covers JSONDecodeError, KeyError the absent entry,
    # TypeError a JSON root that is not an object.
    except (OSError, ValueError, KeyError, TypeError):
        return None


class RepoStatus(Enum):
    """Accessibility of a Hugging Face Hub repository as probed by
    get_repo_status(). The .value strings are surfaced verbatim to
    consumers (e.g. the web UI's "repo_status" field).
    """
    OPEN = "open"            # publicly downloadable, no conditions
    GATED = "gated"          # exists, but user must accept conditions / request access
    NOT_FOUND = "not_found"  # repo id did not resolve on the Hub
    UNKNOWN = "unknown"      # probe failed for another reason (network, auth, ...)


def get_repo_status(repo_id: str) -> RepoStatus:
    """Probe the Hugging Face Hub for the accessibility of *repo_id*.

    Args:
        repo_id: Hub repository identifier, e.g. "org/model-name".

    Returns:
        RepoStatus.OPEN      -- visible and not gated
        RepoStatus.GATED     -- gated repo (either the token already has
                                access and info.gated is the string
                                "auto"/"manual", or access is denied with
                                GatedRepoError)
        RepoStatus.NOT_FOUND -- repo does not exist (or is hidden/private)
        RepoStatus.UNKNOWN   -- any other failure (network, rate limit, ...)
    """
    # Token lookup sits outside the try: huggingface_hub_token() already
    # returns None on any failure, so it cannot raise here, and keeping it
    # out avoids misreporting a config problem as UNKNOWN.
    token = huggingface_hub_token()
    try:
        info = repo_info(repo_id=repo_id, token=token)
        # When the token has access, a gated repo still resolves; the Hub
        # marks it by setting `gated` to "auto" or "manual" (False otherwise).
        if isinstance(info.gated, str):
            return RepoStatus.GATED
        return RepoStatus.OPEN
    except GatedRepoError:
        return RepoStatus.GATED
    except RepositoryNotFoundError:
        return RepoStatus.NOT_FOUND
    except Exception:
        # Deliberate catch-all (but not bare): status probing must never
        # crash the caller; anything unexpected maps to UNKNOWN.
        return RepoStatus.UNKNOWN


if __name__ == "__main__":
    # Manual smoke test (requires network access to the Hub).
    # NOTE: the previous repo id "mistralai/Mixtral-8x7B-Instruct-v0.01"
    # does not exist (always printed NOT_FOUND); the published gated repo
    # is "…-v0.1", so this should print RepoStatus.GATED without a token.
    print(get_repo_status("mistralai/Mixtral-8x7B-Instruct-v0.1"))
9 changes: 9 additions & 0 deletions refact_webgui/webgui/selfhost_model_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from refact_utils.scripts import env
from refact_utils.finetune.utils import get_active_loras
from refact_utils.huggingface.utils import get_repo_status
from refact_webgui.webgui.selfhost_webutils import log
from refact_known_models import models_mini_db, passthrough_mini_db

Expand Down Expand Up @@ -48,6 +49,12 @@ def gpus_shard(self) -> int:

class ModelAssigner:

def __init__(self):
    # Probe the Hugging Face Hub once at startup for every known model so
    # consumers (models_info) can report "open"/"gated"/... without a
    # per-request network call. Keys are model names, values the
    # RepoStatus .value strings.
    # NOTE(review): this performs one network request per entry in
    # models_db during construction -- startup may be slow or degrade to
    # "unknown" statuses when offline; confirm acceptable.
    self._models_repo_status = {
        model_name: get_repo_status(model_info["model_path"]).value
        for model_name, model_info in self.models_db.items()
    }

@property
def models_db(self) -> Dict[str, Any]:
    # Static catalog of self-hostable models (name -> spec dict containing
    # at least "model_path"), backed by refact_known_models.models_mini_db.
    return models_mini_db
Expand Down Expand Up @@ -239,6 +246,8 @@ def models_info(self):
"default_n_ctx": default_n_ctx,
"available_n_ctx": available_n_ctx,
"is_deprecated": bool(rec.get("deprecated", False)),
"repo_status": self._models_repo_status[k],
"repo_url": f"https://huggingface.co/{rec['model_path']}",
})
return {"models": info}

Expand Down
26 changes: 25 additions & 1 deletion refact_webgui/webgui/static/style.css
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,7 @@ h3 {
background-color: #EEEEEEFF; /* or #f7f7f9 */
border-radius: 5px;
padding: 5px;
overflow: hidden;
}

.model-finetune-item-inner {
Expand Down Expand Up @@ -921,6 +922,26 @@ h3 {
/* Sub-rows listing individual model variants in the add-model table;
   hidden by default (presumably toggled by JS when the parent family
   row is expanded -- confirm against tab-model-hosting.js). */
.modelsub-row {
    display: none;
}
.modelsub-row td {
    vertical-align: top;
}
/* Stacks the model name above its optional gated-access notice. */
.modelsub-name {
    display: inline-flex;
    flex-direction: column;
}
.modelsub-name div {
    display: block;
}
/* Small-print notice under a gated model's name (.modelsub-info is
   created in tab-model-hosting.js for repo_status == 'gated'). */
.modelsub-info {
    opacity: 0.7;
    display: block;
    font-size: 11px;
    margin-top: 3px;
}
/* Clickable "settings." span inside the notice, styled link-blue. */
.modelsub-info span {
    color: rgb(13, 110, 253);
    cursor: pointer;
}

.modelsub-row td:first-child {
padding-left: 25px;
Expand Down Expand Up @@ -1176,8 +1197,11 @@ h3 {
}
/* Pill badge appended after a deprecated model's name
   (created in tab-model-hosting.js render functions). */
.deprecated-badge {
    margin-left: 3px;
    /* Stale duplicate `background-color: #ccc;` removed -- it was the
       diff's deleted line and was immediately overridden below.
       !important forces the color over Bootstrap's .badge/.text-dark. */
    background-color: #999 !important;
    color: #fff !important;
    opacity: 0.7;
    font-size: 8px;
    text-transform: uppercase;
}
.default-context {
/* padding: .25rem .5rem; */
Expand Down
50 changes: 39 additions & 11 deletions refact_webgui/webgui/static/tab-model-hosting.js
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ function render_models_assigned(models) {
finetune_info.classList.add('model-finetune-info');
finetune_info.dataset.model = index;

if(models_info[index].hasOwnProperty('is_deprecated') && models_info[index].is_deprecated) {
if(models_info[index].is_deprecated) {
const deprecated_notice = document.createElement('span');
deprecated_notice.classList.add('deprecated-badge','badge','rounded-pill','text-dark');
deprecated_notice.setAttribute('data-bs-toggle','tooltip');
Expand Down Expand Up @@ -452,6 +452,7 @@ function render_models(models) {
}
for (const [key, value] of Object.entries(models_tree)) {
const row = document.createElement('tr');
row.classList.add('model-row');
row.setAttribute('data-model',key);
const model_name = document.createElement("td");
const model_span = document.createElement('span');
Expand All @@ -476,14 +477,30 @@ function render_models(models) {
const has_finetune = document.createElement("td");
const has_chat = document.createElement("td");
model_name.innerHTML = element.name;
if(element.hasOwnProperty('is_deprecated') && element.is_deprecated) {
if(element.repo_status == 'gated') {
model_name.innerHTML = '';
const model_name_div = document.createElement('div');
model_name_div.classList.add('modelsub-name');
const model_holder_div = document.createElement('div');
model_holder_div.innerHTML = element.name;
const model_info_div = document.createElement('div');
model_info_div.classList.add('modelsub-info');
model_info_div.innerHTML = `<b>Gated models downloading requires:</b><br />
1. Huggingface Hub token in <span class="modelinfo-settings">settings.</span><br />
2. Accept conditions at <a target="_blank" href="${element.repo_url}">model's page.</a><br />
Make sure that you have access to this model.<br />
More info about gated model <a target="_blank" href="https://huggingface.co/docs/hub/en/models-gated">here</a>.`;
model_name_div.appendChild(model_holder_div);
model_name_div.appendChild(model_info_div);
model_name.appendChild(model_name_div);
}
if(element.is_deprecated) {
const deprecated_notice = document.createElement('span');
deprecated_notice.classList.add('deprecated-badge','badge','rounded-pill','text-dark');
deprecated_notice.setAttribute('data-bs-toggle','tooltip');
deprecated_notice.setAttribute('data-bs-placement','top');
deprecated_notice.setAttribute('title','Deprecated: this model will be removed in future releases.');
deprecated_notice.textContent = 'Deprecated';
model_name.innerHTML = element.name;
model_name.appendChild(deprecated_notice);
new bootstrap.Tooltip(deprecated_notice);
}
Expand All @@ -502,14 +519,25 @@ function render_models(models) {
row.appendChild(has_chat);
models_table.appendChild(row);
row.addEventListener('click', function(e) {
const model_name = this.dataset.model;
models_data.model_assign[model_name] = {
gpus_shard: 1,
n_ctx: element.default_n_ctx,
};
save_model_assigned();
const add_model_modal = bootstrap.Modal.getOrCreateInstance(document.getElementById('add-model-modal'));
add_model_modal.hide();
if(e.target.classList.contains('modelinfo-settings')) {
document.querySelector('button[data-tab="settings"]').click();
const add_model_modal = bootstrap.Modal.getOrCreateInstance(document.getElementById('add-model-modal'));
add_model_modal.hide();
} else if (e.target.tagName.toLowerCase() === 'a') {
e.preventDefault();
const href = e.target.getAttribute('href');
window.open(href, '_blank');
}
else {
const model_name = this.dataset.model;
models_data.model_assign[model_name] = {
gpus_shard: 1,
n_ctx: element.default_n_ctx,
};
save_model_assigned();
const add_model_modal = bootstrap.Modal.getOrCreateInstance(document.getElementById('add-model-modal'));
add_model_modal.hide();
}
});
});
row.addEventListener('click', function(e) {
Expand Down
8 changes: 8 additions & 0 deletions refact_webgui/webgui/static/tab-settings.html
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
<!-- Hugging Face Hub token pane. The #huggingface_api_key input is
     read/written by tab-settings.js (integrations_input_init /
     save_integration_api_keys) under the "huggingface_api_key" key. -->
<div class="pane">
    <h2>Models</h2>
    <div class="mt-3 mb-3">
        <label for="huggingface_api_key" class="form-label">Huggingface API Token</label>
        <input type="text" name="huggingface_api_key" value="" class="form-control" id="huggingface_api_key">
    </div>
</div>

<div class="pane">
<h2>Integrations</h2>
<div class="mt-3 mb-3 chat-gpt-key">
Expand Down
5 changes: 5 additions & 0 deletions refact_webgui/webgui/static/tab-settings.js
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ function throw_int_saved_success_toast(msg) {
function save_integration_api_keys() {
const openai_api_key = document.getElementById('openai_api_key');
const anthropic_api_key = document.getElementById('anthropic_api_key');
const huggingface_api_key = document.getElementById('huggingface_api_key');
fetch("/tab-settings-integrations-save", {
method: "POST",
headers: {
Expand All @@ -173,13 +174,15 @@ function save_integration_api_keys() {
body: JSON.stringify({
openai_api_key: openai_api_key.getAttribute('data-value'),
anthropic_api_key: anthropic_api_key.getAttribute('data-value'),
huggingface_api_key: huggingface_api_key.getAttribute('data-value'),
})
})
.then(function(response) {
console.log(response);
throw_int_saved_success_toast('API Key saved')
openai_api_key.setAttribute('data-saved-value', openai_api_key.getAttribute('data-value'))
anthropic_api_key.setAttribute('data-saved-value', anthropic_api_key.getAttribute('data-value'))
huggingface_api_key.setAttribute('data-saved-value', huggingface_api_key.getAttribute('data-value'))
});
}

Expand Down Expand Up @@ -212,9 +215,11 @@ export function tab_settings_integrations_get() {
.then(function(data) {
integrations_input_init(document.getElementById('openai_api_key'), data['openai_api_key']);
integrations_input_init(document.getElementById('anthropic_api_key'), data['anthropic_api_key']);
integrations_input_init(document.getElementById('huggingface_api_key'), data['huggingface_api_key']);
});
}


export function tab_switched_here() {
get_ssh_keys();
tab_settings_integrations_get();
Expand Down
1 change: 1 addition & 0 deletions refact_webgui/webgui/tab_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class SSHKey(BaseModel):
class Integrations(BaseModel):
openai_api_key: Optional[str] = None
anthropic_api_key: Optional[str] = None
huggingface_api_key: Optional[str] = None

def __init__(self, models_assigner: ModelAssigner, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down
2 changes: 2 additions & 0 deletions self_hosting_machinery/inference/inference_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sentence_transformers import SentenceTransformer

from refact_utils.scripts import env
from refact_utils.huggingface.utils import huggingface_hub_token
from self_hosting_machinery.inference import InferenceBase
from self_hosting_machinery.inference.lora_loader_mixin import LoraLoaderMixin

Expand Down Expand Up @@ -46,6 +47,7 @@ def __init__(
self._model_dict["model_path"],
device=self._device,
cache_folder=self.cache_dir,
use_auth_token=huggingface_hub_token(),
)
self._model.save(os.path.join(self.cache_dir, self._model_dir))
except Exception as e: # noqa
Expand Down
8 changes: 5 additions & 3 deletions self_hosting_machinery/inference/inference_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from transformers import StoppingCriteriaList
from transformers.generation.streamers import TextStreamer

from refact_utils.huggingface.utils import huggingface_hub_token
from self_hosting_machinery.inference.scratchpad_hf import ScratchpadHuggingfaceBase
from self_hosting_machinery.inference.scratchpad_hf import ScratchpadHuggingfaceCompletion
from self_hosting_machinery.inference import InferenceBase
Expand Down Expand Up @@ -147,12 +148,13 @@ def __init__(self,
assert torch.cuda.is_available(), "model is only supported on GPU"

self._device = "cuda:0"
token = huggingface_hub_token()
for local_files_only in [True, False]:
try:
logging.getLogger("MODEL").info("loading model local_files_only=%i" % local_files_only)
self._tokenizer = AutoTokenizer.from_pretrained(
self._model_dict["model_path"], cache_dir=self.cache_dir, trust_remote_code=True,
local_files_only=local_files_only,
local_files_only=local_files_only, token=token,
)
if model_dict["backend"] == "transformers":
torch_dtype_mapping = {
Expand All @@ -165,13 +167,13 @@ def __init__(self,
self._model = AutoModelForCausalLM.from_pretrained(
self._model_dict["model_path"], cache_dir=self.cache_dir,
device_map="auto", torch_dtype=torch_dtype, trust_remote_code=True,
local_files_only=local_files_only,
local_files_only=local_files_only, token=token,
**self._model_dict["model_class_kwargs"])
elif model_dict["backend"] == "autogptq":
self._model = CustomAutoGPTQForCausalLM.from_quantized(
self._model_dict["model_path"], cache_dir=self.cache_dir, device=self._device,
trust_remote_code=True,
local_files_only=local_files_only,
local_files_only=local_files_only, token=token,
**self._model_dict["model_class_kwargs"])
else:
raise RuntimeError(f"unknown model backend {model_dict['backend']}")
Expand Down